Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support --empty flag for schema-only dry runs #8971

Merged
merged 16 commits into from
Nov 29, 2023
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .changes/unreleased/Features-20231116-234049.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Features
body: Support --empty flag for schema-only dry runs
time: 2023-11-16T23:40:49.96651-05:00
custom:
Author: michelleark
Issue: "8971"
18 changes: 13 additions & 5 deletions core/dbt/adapters/base/relation.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class BaseRelation(FakeAPIObject, Hashable):
include_policy: Policy = field(default_factory=lambda: Policy())
quote_policy: Policy = field(default_factory=lambda: Policy())
dbt_created: bool = False
limit: Optional[int] = None

# register relation types that can be renamed for the purpose of replacing relations using stages and backups
# adding a relation type here also requires defining the associated rename macro
Expand Down Expand Up @@ -194,6 +195,15 @@ def render(self) -> str:
# if there is nothing set, this will return the empty string.
return ".".join(part for _, part in self._render_iterator() if part is not None)

def render_limited(self) -> str:
rendered = self.render()
if self.limit is None:
MichelleArk marked this conversation as resolved.
Show resolved Hide resolved
return rendered
elif self.limit == 0:
return f"(select * from {rendered} where false limit 0) _dbt_limit_subq"
else:
return f"(select * from {rendered} limit {self.limit}) _dbt_limit_subq"

def quoted(self, identifier):
return "{quote_char}{identifier}{quote_char}".format(
quote_char=self.quote_character,
Expand Down Expand Up @@ -227,13 +237,11 @@ def create_ephemeral_from_node(
cls: Type[Self],
config: HasQuoting,
node: ManifestNode,
**kwargs: Any,
) -> Self:
# Note that ephemeral models are based on the name.
identifier = cls.add_ephemeral_prefix(node.name)
return cls.create(
type=cls.CTE,
identifier=identifier,
).quote(identifier=False)
return cls.create(type=cls.CTE, identifier=identifier, **kwargs).quote(identifier=False)
MichelleArk marked this conversation as resolved.
Show resolved Hide resolved

@classmethod
def create_from_node(
Expand Down Expand Up @@ -313,7 +321,7 @@ def __hash__(self) -> int:
return hash(self.render())

def __str__(self) -> str:
return self.render()
return self.render() if self.limit is None else self.render_limited()

@property
def database(self) -> Optional[str]:
Expand Down
2 changes: 2 additions & 0 deletions core/dbt/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,7 @@ def docs_serve(ctx, **kwargs):
@p.profile
@p.profiles_dir
@p.project_dir
@p.empty
@p.select
@p.selector
@p.inline
Expand Down Expand Up @@ -599,6 +600,7 @@ def parse(ctx, **kwargs):
@p.profile
@p.profiles_dir
@p.project_dir
@p.empty
@p.select
@p.selector
@p.state
Expand Down
6 changes: 6 additions & 0 deletions core/dbt/cli/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,12 @@
is_flag=True,
)

empty = click.option(
"--empty",
envvar="DBT_EMPTY",
help="If specified, limit input refs and sources to zero rows.",
is_flag=True,
)

enable_legacy_logger = click.option(
"--enable-legacy-logger/--no-enable-legacy-logger",
Expand Down
12 changes: 9 additions & 3 deletions core/dbt/context/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,10 @@ def current_project(self):
def Relation(self):
return self.db_wrapper.Relation

@property
def resolve_limit(self) -> Optional[int]:
return 0 if getattr(self.config.args, "EMPTY", False) else None

@abc.abstractmethod
def __call__(self, *args: str) -> Union[str, RelationProxy, MetricReference]:
pass
Expand Down Expand Up @@ -531,9 +535,11 @@ def resolve(
def create_relation(self, target_model: ManifestNode) -> RelationProxy:
if target_model.is_ephemeral_model:
self.model.set_cte(target_model.unique_id, None)
return self.Relation.create_ephemeral_from_node(self.config, target_model)
return self.Relation.create_ephemeral_from_node(
self.config, target_model, limit=self.resolve_limit
)
else:
return self.Relation.create_from(self.config, target_model)
return self.Relation.create_from(self.config, target_model, limit=self.resolve_limit)

def validate(
self,
Expand Down Expand Up @@ -590,7 +596,7 @@ def resolve(self, source_name: str, table_name: str):
target_kind="source",
disabled=(isinstance(target_source, Disabled)),
)
return self.Relation.create_from_source(target_source)
return self.Relation.create_from_source(target_source, limit=self.resolve_limit)


# metric` implementations
Expand Down
75 changes: 75 additions & 0 deletions tests/adapter/dbt/tests/adapter/empty/test_empty.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import pytest
from dbt.tests.util import run_dbt, relation_from_name


model_input_sql = """
select 1 as id
"""

ephemeral_model_input_sql = """
{{ config(materialized='ephemeral') }}
select 2 as id
"""

raw_source_csv = """id
3
"""


model_sql = """
select *
from {{ ref('model_input') }}
union all
select *
from {{ ref('ephemeral_model_input') }}
union all
select *
from {{ source('seed_sources', 'raw_source') }}
"""


schema_sources_yml = """
sources:
- name: seed_sources
schema: "{{ target.schema }}"
tables:
- name: raw_source
"""


class BaseTestEmpty:
@pytest.fixture(scope="class")
def seeds(self):
return {
"raw_source.csv": raw_source_csv,
}

@pytest.fixture(scope="class")
def models(self):
return {
"model_input.sql": model_input_sql,
"ephemeral_model_input.sql": ephemeral_model_input_sql,
"model.sql": model_sql,
"sources.yml": schema_sources_yml,
}

def assert_row_count(self, project, relation_name: str, expected_row_count: int):
relation = relation_from_name(project.adapter, relation_name)
result = project.run_sql(f"select count(*) as num_rows from {relation}", fetch="one")
assert result[0] == expected_row_count

def test_run_with_empty(self, project):
# create source from seed
run_dbt(["seed"])

# run without empty - 3 expected rows in output - 1 from each input
run_dbt(["run"])
self.assert_row_count(project, "model", 3)

# run with empty - 0 expected rows in output
run_dbt(["run", "--empty"])
self.assert_row_count(project, "model", 0)


class TestEmpty(BaseTestEmpty):
MichelleArk marked this conversation as resolved.
Show resolved Hide resolved
pass
104 changes: 104 additions & 0 deletions tests/unit/test_providers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import pytest
from unittest import mock

from dbt.adapters.base import BaseRelation
from dbt.context.providers import BaseResolver, RuntimeRefResolver, RuntimeSourceResolver
from dbt.contracts.graph.unparsed import Quoting


class TestBaseResolver:
class ResolverSubclass(BaseResolver):
def __call__(self, *args: str):
pass

@pytest.fixture
def resolver(self):
return self.ResolverSubclass(
db_wrapper=mock.Mock(),
model=mock.Mock(),
config=mock.Mock(),
manifest=mock.Mock(),
)

@pytest.mark.parametrize(
"empty,expected_resolve_limit",
[(False, None), (True, 0)],
)
def test_resolve_limit(self, resolver, empty, expected_resolve_limit):
resolver.config.args.EMPTY = empty

assert resolver.resolve_limit == expected_resolve_limit


class TestRuntimeRefResolver:
@pytest.fixture
def resolver(self):
mock_db_wrapper = mock.Mock()
mock_db_wrapper.Relation = BaseRelation

return RuntimeRefResolver(
db_wrapper=mock_db_wrapper,
model=mock.Mock(),
config=mock.Mock(),
manifest=mock.Mock(),
)

@pytest.mark.parametrize(
"empty,is_ephemeral_model,expected_limit",
[
(False, False, None),
(True, False, 0),
(False, True, None),
(True, True, 0),
],
)
def test_create_relation_with_empty(self, resolver, empty, is_ephemeral_model, expected_limit):
# setup resolver and input node
resolver.config.args.EMPTY = empty
mock_node = mock.Mock()
mock_node.database = "test"
mock_node.schema = "test"
mock_node.identifier = "test"
mock_node.alias = "test"
mock_node.is_ephemeral_model = is_ephemeral_model

# create limited relation
with mock.patch("dbt.adapters.base.relation.ParsedNode", new=mock.Mock):
relation = resolver.create_relation(mock_node)
assert relation.limit == expected_limit


class TestRuntimeSourceResolver:
@pytest.fixture
def resolver(self):
mock_db_wrapper = mock.Mock()
mock_db_wrapper.Relation = BaseRelation

return RuntimeSourceResolver(
db_wrapper=mock_db_wrapper,
model=mock.Mock(),
config=mock.Mock(),
manifest=mock.Mock(),
)

@pytest.mark.parametrize(
"empty,expected_limit",
[
(False, None),
(True, 0),
],
)
def test_create_relation_with_empty(self, resolver, empty, expected_limit):
# setup resolver and input source
resolver.config.args.EMPTY = empty

mock_source = mock.Mock()
mock_source.database = "test"
mock_source.schema = "test"
mock_source.identifier = "test"
mock_source.quoting = Quoting()
resolver.manifest.resolve_source.return_value = mock_source

# create limited relation
relation = resolver.resolve("test", "test")
assert relation.limit == expected_limit
26 changes: 26 additions & 0 deletions tests/unit/test_relation.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,29 @@ def test_can_be_replaced(relation_type, result):
def test_can_be_replaced_default():
my_relation = BaseRelation.create(type=RelationType.View)
assert my_relation.can_be_replaced is False


@pytest.mark.parametrize(
"limit,expected_result",
[
(None, '"test_database"."test_schema"."test_identifier"'),
(
0,
'(select * from "test_database"."test_schema"."test_identifier" where false limit 0) _dbt_limit_subq',
),
(
1,
'(select * from "test_database"."test_schema"."test_identifier" limit 1) _dbt_limit_subq',
),
],
)
def test_render_limited(limit, expected_result):
my_relation = BaseRelation.create(
database="test_database",
schema="test_schema",
identifier="test_identifier",
limit=limit,
)
actual_result = my_relation.render_limited()
assert actual_result == expected_result
assert str(my_relation) == expected_result
Loading