Skip to content

Commit

Permalink
Further improvements to dbt sort order linter (#647)
Browse files (browse the repository at this point in the history)
* Add a trailing slash to venv/ exclusion in pre-commit config

* Add a docstring to the check_sort_dbt_yaml_files.py script

* Sort unit_tests in addition to data_tests in check_sort_dbt_yaml_files
  • Loading branch information
jeancochrane authored Nov 18, 2024
1 parent 6a2a3f9 commit d87dd94
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 42 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,4 @@ repos:
language: system
types_or: [yaml, markdown]
files: ^dbt/
exclude: venv
exclude: venv/
113 changes: 72 additions & 41 deletions dbt/scripts/check_sort_dbt_yaml_files.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,20 @@
# Script that acts as a linter to check that our dbt assets are ordered
# alphabetically. Attributes that we check for ordering include:
#
# * `columns`, `data_tests`, and `unit_tests` in `schema.yml` files
# * Headings in all docs Markdown files, with special cases for `columns.md`
# and `shared_columns.md` files that have specific sub-headings
#
# When positional arguments are present, the script interprets them as a list of
# filepaths to lint. When no positional arguments are present, the script will
# list all files recursively under the current working directory and lint
# everything it finds with a filename or file extension that matches the assets
# listed above.
#
# We primarily run this script via pre-commit, which automatically passes in
# names of files that have changed. The default behavior when no arguments are
# present exists to support running this script outside the context of
# pre-commit.
import os
import re
import sys
Expand Down Expand Up @@ -130,9 +147,10 @@ def check_columns_in_yaml(
return unsorted_files_dict, []


def check_data_tests(file_path):
def check_tests(file_path):
"""
Check if the 'data_tests' sections in a YAML file are sorted.
Check if the 'data_tests' and 'unit_tests' sections in a YAML file are
sorted.
Args:
file_path (str): The path to the YAML file to check.
Expand All @@ -146,11 +164,12 @@ def check_data_tests(file_path):
except yaml.YAMLError as error:
return [error], [file_path]

def check_data_tests_in_yaml(
def check_tests_in_yaml(
data, file_path, unsorted_files_dict, parent_key=None
):
"""
Recursively check the 'data_tests' sections in a YAML structure for sorting.
Recursively check the 'data_tests' and 'unit_tests' sections in a YAML
structure for sorting.
Args:
data (dict or list): The YAML data to check.
Expand All @@ -160,46 +179,60 @@ def check_data_tests_in_yaml(
"""
if isinstance(data, dict):
for key, value in data.items():
if key == "data_tests" and isinstance(value, list):
data_test_names = []
if key in ("data_tests", "unit_tests") and isinstance(
value, list
):
test_names = []
for test in value:
if isinstance(test, dict):
for test_type, test_details in test.items():
if (
isinstance(test_details, dict)
and "name" in test_details
):
data_test_names.append(
(test_type, test_details["name"], test)
)
# Unit tests have 'name' as a top-level attribute,
# while data tests nest it inside the value keyed to
# the name of the generic test. Hence if 'name' is
# present, it means we're dealing with a unit test,
# and otherwise we need to dig further to get the
# name
if "name" in test:
test_names.append(test["name"])
else:
for test_type, test_details in test.items():
if (
isinstance(test_details, dict)
and "name" in test_details
):
test_names.append(test_details["name"])

sorted_tests = sorted(
data_test_names,
key=lambda x: alphanumeric_key(normalize_string(x[1])),
test_names,
key=lambda x: alphanumeric_key(normalize_string(x)),
)
if data_test_names != sorted_tests:
if test_names != sorted_tests:
print(f"In file: {file_path}")
print(f"Key above 'data_tests': {parent_key}")
print("Data tests in this group are not sorted:")
for i, (_, name, _) in enumerate(data_test_names):
if name != sorted_tests[i][1]:
print(f"---> {name}")
# Unit tests are top-level attributes, so there will be
# no parent key in their case
if parent_key:
print(f"Top level: {parent_key}")
print(f"{key} in this group are not sorted:")
for test_name, sorted_test_name in zip(
test_names, sorted_tests
):
if test_name != sorted_test_name:
print(f"---> {test_name}")
else:
print(f"- {name}")
print(f"- {test_name}")
print("-" * 40)
unsorted_files_dict[file_path] += 1
else:
check_data_tests_in_yaml(
check_tests_in_yaml(
value, file_path, unsorted_files_dict, key
)
elif isinstance(data, list):
for item in data:
check_data_tests_in_yaml(
check_tests_in_yaml(
item, file_path, unsorted_files_dict, parent_key
)

unsorted_files_dict = defaultdict(int)
check_data_tests_in_yaml(data, file_path, unsorted_files_dict)
check_tests_in_yaml(data, file_path, unsorted_files_dict)
return unsorted_files_dict, []


Expand Down Expand Up @@ -368,7 +401,7 @@ def check_files(file_paths: list[str]):
tuple: Results of unsorted files and errors for different checks.
"""
unsorted_columns_files: dict[str, int] = defaultdict(int)
unsorted_data_tests_files: dict[str, int] = defaultdict(int)
unsorted_tests_files: dict[str, int] = defaultdict(int)
error_files = []
unsorted_md_files = []
unsorted_columns_md_files = []
Expand All @@ -382,9 +415,9 @@ def check_files(file_paths: list[str]):
unsorted_columns, errors = check_columns(file_path)
for key, value in unsorted_columns.items():
unsorted_columns_files[key] += value
unsorted_data_tests, errors = check_data_tests(file_path)
for key, value in unsorted_data_tests.items():
unsorted_data_tests_files[key] += value
unsorted_tests, errors = check_tests(file_path)
for key, value in unsorted_tests.items():
unsorted_tests_files[key] += value
if errors:
error_files.extend(errors)
elif os.path.basename(file_path) == "docs.md":
Expand All @@ -406,7 +439,7 @@ def check_files(file_paths: list[str]):

return (
unsorted_columns_files,
unsorted_data_tests_files,
unsorted_tests_files,
error_files,
unsorted_md_files,
unsorted_columns_md_files,
Expand All @@ -419,7 +452,7 @@ def check_files(file_paths: list[str]):
if args:
(
unsorted_columns_files,
unsorted_data_tests_files,
unsorted_tests_files,
error_files,
unsorted_md_files,
unsorted_columns_md_files,
Expand All @@ -428,7 +461,7 @@ def check_files(file_paths: list[str]):
else:
(
unsorted_columns_files,
unsorted_data_tests_files,
unsorted_tests_files,
error_files,
unsorted_md_files,
unsorted_columns_md_files,
Expand All @@ -440,9 +473,9 @@ def check_files(file_paths: list[str]):
for file, count in unsorted_columns_files.items():
print(f"{file} ({count})")

if unsorted_data_tests_files:
print("\nThe following files have unsorted data tests:")
for file, count in unsorted_data_tests_files.items():
if unsorted_tests_files:
print("\nThe following files have unsorted tests:")
for file, count in unsorted_tests_files.items():
print(f"{file} ({count})")

if unsorted_md_files:
Expand All @@ -465,16 +498,14 @@ def check_files(file_paths: list[str]):
print("\n")
if (
unsorted_columns_files
or unsorted_data_tests_files
or unsorted_tests_files
or error_files
or unsorted_md_files
or unsorted_columns_md_files
or unsorted_shared_columns_md_files
):
raise ValueError(
"Column name, data test, or heading sort order check ran into failures, see logs above"
"Column name, test, or heading sort order check ran into failures, see logs above"
)

print(
"All files have sorted columns, data tests, headings, and no errors."
)
print("All files have sorted columns, tests, headings, and no errors.")

0 comments on commit d87dd94

Please sign in to comment.