Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

convert to use unit test name as top-level key #8966

Merged
merged 12 commits into from
Nov 3, 2023
9 changes: 2 additions & 7 deletions core/dbt/contracts/graph/unparsed.py
Original file line number Diff line number Diff line change
Expand Up @@ -830,16 +830,11 @@ class UnitTestOverrides(dbtClassMixin):


@dataclass
class UnparsedUnitTestDefinition(dbtClassMixin):
class UnparsedUnitTest(dbtClassMixin):
name: str
model: str # name of the model being unit tested
given: Sequence[UnitTestInputFixture]
expect: UnitTestOutputFixture
description: str = ""
overrides: Optional[UnitTestOverrides] = None
config: Dict[str, Any] = field(default_factory=dict)


@dataclass
class UnparsedUnitTestSuite(dbtClassMixin):
model: str # name of the model being unit tested
tests: Sequence[UnparsedUnitTestDefinition]
16 changes: 7 additions & 9 deletions core/dbt/parser/partial.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,7 +608,7 @@
self.saved_manifest.files.pop(file_id)

# For each key in a schema file dictionary, process the changed, deleted, and added
# elemnts for the key lists
# elements for the key lists
def handle_schema_file_changes(self, schema_file, saved_yaml_dict, new_yaml_dict):
# loop through comparing previous dict_from_yaml with current dict_from_yaml
# Need to do the deleted/added/changed thing, just like the files lists
Expand Down Expand Up @@ -711,7 +711,6 @@
# Take a "section" of the schema file yaml dictionary from saved and new schema files
# and determine which parts have changed
def get_diff_for(self, key, saved_yaml_dict, new_yaml_dict):
dict_name = "model" if key == "unit_tests" else "name"
if key in saved_yaml_dict or key in new_yaml_dict:
saved_elements = saved_yaml_dict[key] if key in saved_yaml_dict else []
new_elements = new_yaml_dict[key] if key in new_yaml_dict else []
Expand All @@ -722,9 +721,9 @@
new_elements_by_name = {}
# sources have two part names?
for element in saved_elements:
saved_elements_by_name[element[dict_name]] = element
saved_elements_by_name[element["name"]] = element
for element in new_elements:
new_elements_by_name[element[dict_name]] = element
new_elements_by_name[element["name"]] = element

# now determine which elements, by name, are added, deleted or changed
saved_element_names = set(saved_elements_by_name.keys())
Expand Down Expand Up @@ -754,7 +753,6 @@
# flag indicates that we're processing a schema file, so if a matching
# patch has already been scheduled, replace it.
def merge_patch(self, schema_file, key, patch, new_patch=False):
elem_name = "model" if key == "unit_tests" else "name"
if schema_file.pp_dict is None:
schema_file.pp_dict = {}
pp_dict = schema_file.pp_dict
Expand All @@ -764,7 +762,7 @@
# check that this patch hasn't already been saved
found_elem = None
for elem in pp_dict[key]:
if elem["name"] == patch[elem_name]:
if elem["name"] == patch["name"]:
found_elem = elem
if not found_elem:
pp_dict[key].append(patch)
Expand All @@ -773,7 +771,7 @@
pp_dict[key].remove(found_elem)
pp_dict[key].append(patch)

schema_file.delete_from_env_vars(key, patch[elem_name])
schema_file.delete_from_env_vars(key, patch["name"])
self.add_to_pp_files(schema_file)

# For model, seed, snapshot, analysis schema dictionary keys,
Expand Down Expand Up @@ -942,12 +940,12 @@
self.delete_disabled(unique_id, schema_file.file_id)

def delete_schema_unit_test(self, schema_file, unit_test_dict):
unit_test_model_name = unit_test_dict["model"]
unit_test_name = unit_test_dict["name"]

Check warning on line 943 in core/dbt/parser/partial.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/parser/partial.py#L943

Added line #L943 was not covered by tests
unit_tests = schema_file.unit_tests.copy()
for unique_id in unit_tests:
if unique_id in self.saved_manifest.unit_tests:
unit_test = self.saved_manifest.unit_tests[unique_id]
if unit_test.model == unit_test_model_name:
if unit_test.name == unit_test_name:

Check warning on line 948 in core/dbt/parser/partial.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/parser/partial.py#L948

Added line #L948 was not covered by tests
self.saved_manifest.unit_tests.pop(unique_id)
schema_file.unit_tests.remove(unique_id)
# No disabled unit tests yet
Expand Down
76 changes: 38 additions & 38 deletions core/dbt/parser/unit_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
DependsOn,
UnitTestConfig,
)
from dbt.contracts.graph.unparsed import UnparsedUnitTestSuite
from dbt.contracts.graph.unparsed import UnparsedUnitTest
from dbt.exceptions import ParsingError, InvalidUnitTestGivenInput
from dbt.graph import UniqueId
from dbt.node_types import NodeType
Expand Down Expand Up @@ -199,50 +199,50 @@

def parse(self) -> ParseResult:
for data in self.get_key_dicts():
unit_test_suite = self._get_unit_test_suite(data)
model_name_split = unit_test_suite.model.split()
tested_model_node = self._find_tested_model_node(unit_test_suite)

for test in unit_test_suite.tests:
unit_test_case_unique_id = f"{NodeType.Unit}.{self.project.project_name}.{unit_test_suite.model}.{test.name}"
unit_test_fqn = [self.project.project_name] + model_name_split + [test.name]
unit_test_config = self._build_unit_test_config(unit_test_fqn, test.config)

# Check that format and type of rows matches for each given input
for input in test.given:
input.validate_fixture("input", test.name)
test.expect.validate_fixture("expected", test.name)

unit_test_definition = UnitTestDefinition(
name=test.name,
model=unit_test_suite.model,
resource_type=NodeType.Unit,
package_name=self.project.project_name,
path=self.yaml.path.relative_path,
original_file_path=self.yaml.path.original_file_path,
unique_id=unit_test_case_unique_id,
given=test.given,
expect=test.expect,
description=test.description,
overrides=test.overrides,
depends_on=DependsOn(nodes=[tested_model_node.unique_id]),
fqn=unit_test_fqn,
config=unit_test_config,
)
self.manifest.add_unit_test(self.yaml.file, unit_test_definition)
unit_test = self._get_unit_test(data)
model_name_split = unit_test.model.split()
tested_model_node = self._find_tested_model_node(unit_test)
unit_test_case_unique_id = (
f"{NodeType.Unit}.{self.project.project_name}.{unit_test.model}.{unit_test.name}"
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason we were hard-coding `unit` here? It doesn't match how we assign the unique ids for other node types.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think there's an explicit reason. What would be a more conventional way to build the unique_id here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using {NodeType.Unit} instead of a hardcoded unit is more conventional. I went ahead and made the change because it didn't seem intentional but wanted to check in case.

)
unit_test_fqn = [self.project.project_name] + model_name_split + [unit_test.name]
unit_test_config = self._build_unit_test_config(unit_test_fqn, unit_test.config)

# Check that format and type of rows matches for each given input
for input in unit_test.given:
input.validate_fixture("input", unit_test.name)

Check warning on line 213 in core/dbt/parser/unit_tests.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/parser/unit_tests.py#L213

Added line #L213 was not covered by tests
unit_test.expect.validate_fixture("expected", unit_test.name)

unit_test_definition = UnitTestDefinition(
name=unit_test.name,
model=unit_test.model,
resource_type=NodeType.Unit,
package_name=self.project.project_name,
path=self.yaml.path.relative_path,
original_file_path=self.yaml.path.original_file_path,
unique_id=unit_test_case_unique_id,
given=unit_test.given,
expect=unit_test.expect,
description=unit_test.description,
overrides=unit_test.overrides,
depends_on=DependsOn(nodes=[tested_model_node.unique_id]),
fqn=unit_test_fqn,
config=unit_test_config,
)
self.manifest.add_unit_test(self.yaml.file, unit_test_definition)

return ParseResult()

def _get_unit_test_suite(self, data: Dict[str, Any]) -> UnparsedUnitTestSuite:
def _get_unit_test(self, data: Dict[str, Any]) -> UnparsedUnitTest:
try:
UnparsedUnitTestSuite.validate(data)
return UnparsedUnitTestSuite.from_dict(data)
UnparsedUnitTest.validate(data)
return UnparsedUnitTest.from_dict(data)
except (ValidationError, JSONValidationError) as exc:
raise YamlParseDictError(self.yaml.path, self.key, data, exc)

def _find_tested_model_node(self, unit_test_suite: UnparsedUnitTestSuite) -> ModelNode:
def _find_tested_model_node(self, unit_test: UnparsedUnitTest) -> ModelNode:
package_name = self.project.project_name
model_name_split = unit_test_suite.model.split()
model_name_split = unit_test.model.split()
model_name = model_name_split[0]
model_version = model_name_split[1] if len(model_name_split) == 2 else None

Expand All @@ -251,7 +251,7 @@
)
if not tested_node:
raise ParsingError(
f"Unable to find model '{package_name}.{unit_test_suite.model}' for unit tests in {self.yaml.path.original_file_path}"
f"Unable to find model '{package_name}.{unit_test.model}' for unit tests in {self.yaml.path.original_file_path}"
)

return tested_node
Expand Down
3 changes: 2 additions & 1 deletion core/dbt/task/unit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,8 @@ def before_execute(self):
self.print_start_line()

def execute_unit_test(self, node: UnitTestNode, manifest: Manifest) -> UnitTestResultData:
# generate_runtime_unit_test_context not strictly needed - this is to run the 'unit' materialization, not compile the node.compield_code
# generate_runtime_unit_test_context not strictly needed - this is to run the 'unit'
# materialization, not compile the node.compiled_code
context = generate_runtime_model_context(node, self.config, manifest)

materialization_macro = manifest.find_materialization_macro_by_name(
Expand Down
Loading