Skip to content

Commit

Permalink
Throw ParsingError for invalid PK definition on constraints (#9700)
Browse files Browse the repository at this point in the history
* audit pks on constraints

* pr feedback
  • Loading branch information
emmyoop authored Feb 29, 2024
1 parent f48a927 commit ce10240
Show file tree
Hide file tree
Showing 4 changed files with 277 additions and 2 deletions.
7 changes: 7 additions & 0 deletions .changes/unreleased/Fixes-20240228-135928.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
kind: Fixes
body: Throw a ParsingError if a primary key constraint is defined on multiple columns
or at both the column and model level.
time: 2024-02-28T13:59:28.728561-06:00
custom:
Author: emmyoop
Issue: "9581"
37 changes: 35 additions & 2 deletions core/dbt/parser/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -897,10 +897,43 @@ def patch_constraints(self, node, constraints):
f"Type must be one of {[ct.value for ct in ConstraintType]}"
)

node.constraints = [ModelLevelConstraint.from_dict(c) for c in constraints]
self._validate_pk_constraints(node, constraints)
node.constraints = [ModelLevelConstraint.from_dict(c) for c in constraints]

def _validate_constraint_prerequisites(self, model_node: ModelNode):
def _validate_pk_constraints(self, model_node: ModelNode, constraints: List[Dict[str, Any]]):
    """Reject ambiguous primary key definitions on ``model_node``.

    Two problems are detected, and both are reported together when both apply:

    * more than one column-level primary key constraint was found;
    * a column-level primary key exists while a primary key is also present
      at the model level, either already on ``model_node.constraints`` or in
      the incoming ``constraints`` dicts.

    Args:
        model_node: node whose columns and existing constraints are inspected.
        constraints: raw model-level constraint dicts; each is read via its
            ``"type"`` key.

    Raises:
        ParsingError: if at least one problem is found. The message starts with
            the node's ``original_file_path`` and lists one line per problem.
    """
    # One entry per column-level primary key constraint, following the
    # iteration order of model_node.columns.
    pk_column_names: List[str] = [
        column.name
        for column in model_node.columns.values()
        for column_constraint in column.constraints
        if column_constraint.type == ConstraintType.primary_key
    ]

    problems = []

    if len(pk_column_names) > 1:
        problems.append(
            f"Found {len(pk_column_names)} columns ({pk_column_names}) with primary key constraints defined. "
            "Primary keys for multiple columns must be defined as a model level constraint."
        )

    if pk_column_names:
        # Constraints already on the node are objects (attribute access); the
        # incoming ones are plain dicts. The `or` keeps the second scan lazy.
        model_level_pk = any(
            existing.type == ConstraintType.primary_key for existing in model_node.constraints
        ) or any(incoming["type"] == ConstraintType.primary_key for incoming in constraints)
        if model_level_pk:
            problems.append(
                "Primary key constraints defined at the model level and the columns level. "
                "Primary keys can be defined at the model level or the column level, not both."
            )

    if not problems:
        return

    raise ParsingError(
        f"Primary key constraint error: ({model_node.original_file_path})\n"
        + "\n".join(problems)
    )

def _validate_constraint_prerequisites(self, model_node: ModelNode):
column_warn_unsupported = [
constraint.warn_unsupported
for column in model_node.columns.values()
Expand Down
85 changes: 85 additions & 0 deletions tests/functional/configs/test_contract_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,59 @@ def model(dbt, _):
data_type: date
"""

# Schema fixture for `my_model`: lists a primary_key constraint with
# `columns: [id]` and also puts a primary_key constraint on the `id` column.
# Used by TestPrimaryKeysModelAndColumnLevelConstraints, which expects a ParsingError.
# NOTE(review): YAML nesting is not visible in this view (indentation lost) — confirm
# the first `constraints:` key sits at the model level.
model_pk_model_column_schema_yml = """
models:
- name: my_model
config:
contract:
enforced: true
constraints:
- type: primary_key
columns: [id]
columns:
- name: id
data_type: integer
description: hello
constraints:
- type: not_null
- type: primary_key
- type: check
expression: (id > 0)
data_tests:
- unique
- name: color
data_type: string
- name: date_day
data_type: date
"""

# Schema fixture for `my_model`: primary_key constraints appear under two
# columns (`id` and `color`). Used by TestPrimaryKeysMultipleColumns, which
# expects the "Found 2 columns" ParsingError.
model_pk_mult_column_schema_yml = """
models:
- name: my_model
config:
contract:
enforced: true
columns:
- name: id
quote: true
data_type: integer
description: hello
constraints:
- type: not_null
- type: primary_key
- type: check
expression: (id > 0)
data_tests:
- unique
- name: color
data_type: string
constraints:
- type: not_null
- type: primary_key
- name: date_day
data_type: date
"""

model_schema_alias_types_false_yml = """
models:
- name: my_model
Expand Down Expand Up @@ -514,3 +567,35 @@ def test__missing_column_contract_error(self, project):
"This model has an enforced contract, and its 'columns' specification is missing"
)
assert expected_error in results[0].message


# test primary key defined across model and column level constraints, expect error
class TestPrimaryKeysModelAndColumnLevelConstraints:
    """`dbt run` must raise ParsingError when a PK is set on the model and on a column."""

    @pytest.fixture(scope="class")
    def models(self):
        # Project files: the schema with the conflicting PK definitions plus the model SQL.
        project_files = {}
        project_files["constraints_schema.yml"] = model_pk_model_column_schema_yml
        project_files["my_model.sql"] = my_model_sql
        return project_files

    def test_model_column_pk_error(self, project):
        with pytest.raises(ParsingError) as raised:
            run_dbt(["run"])

        message = str(raised.value)
        assert (
            "Primary key constraints defined at the model level and the columns level" in message
        )


# test primary key defined across multiple columns, expect error
class TestPrimaryKeysMultipleColumns:
    """`dbt run` must raise ParsingError when two columns each declare a primary key."""

    @pytest.fixture(scope="class")
    def models(self):
        # Project files: the schema with two column-level PKs plus the model SQL.
        project_files = {}
        project_files["constraints_schema.yml"] = model_pk_mult_column_schema_yml
        project_files["my_model.sql"] = my_model_sql
        return project_files

    def test_pk_multiple_columns(self, project):
        with pytest.raises(ParsingError) as raised:
            run_dbt(["run"])

        message = str(raised.value)
        assert (
            "Found 2 columns (['id', 'color']) with primary key constraints defined" in message
        )
150 changes: 150 additions & 0 deletions tests/functional/configs/test_versioned_model_constraint.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pytest
from dbt.tests.util import run_dbt, rm_file, write_file, get_manifest
from dbt.exceptions import ParsingError


schema_yml = """
Expand All @@ -25,6 +26,10 @@
select 1 as id, 'alice' as user_name
"""

# Model SQL selecting three literal columns: id, user_name and another_pk.
# NOTE(review): no test visible in this view references this fixture — the
# foo_v2.sql files below are written from foo_sql. Confirm whether it is needed.
foo_v2_sql = """
select 1 as id, 'alice' as user_name, 2 as another_pk
"""

versioned_schema_yml = """
models:
- name: foo
Expand All @@ -47,6 +52,69 @@
- v: 1
"""

# Versioned schema fixture for `foo` (latest_version: 2): lists a primary_key
# constraint with `columns: [id]`, and the `v: 2` entry also declares
# primary_key on its `id` column. Written to schema.yml by
# TestPrimaryKeysModelAndColumnLevelConstraints, which expects a ParsingError.
# NOTE(review): YAML nesting is not visible in this view (indentation lost) — confirm
# which keys belong to the model and which to the v2 entry.
versioned_pk_model_column_schema_yml = """
models:
- name: foo
latest_version: 2
config:
materialized: table
contract:
enforced: true
constraints:
- type: primary_key
columns: [id]
columns:
- name: id
data_type: int
constraints:
- type: not_null
- name: user_name
data_type: text
versions:
- v: 1
- v: 2
columns:
- name: id
data_type: int
constraints:
- type: not_null
- type: primary_key
- name: user_name
data_type: text
"""

# Versioned schema fixture for `foo` (latest_version: 2): after the `v: 2`
# entry, primary_key is declared on both `id` and `user_name`. Written to
# schema.yml by TestPrimaryKeysMultipleColumns, which expects the
# "Found 2 columns" ParsingError.
versioned_pk_mult_columns_schema_yml = """
models:
- name: foo
latest_version: 2
config:
materialized: table
contract:
enforced: true
columns:
- name: id
data_type: int
constraints:
- type: not_null
- type: primary_key
- name: user_name
data_type: text
versions:
- v: 1
- v: 2
columns:
- name: id
data_type: int
constraints:
- type: not_null
- type: primary_key
- name: user_name
data_type: text
constraints:
- type: primary_key
"""


class TestVersionedModelConstraints:
@pytest.fixture(scope="class")
Expand Down Expand Up @@ -74,3 +142,85 @@ def test_versioned_model_constraints(self, project):
model_node = manifest.nodes["model.test.foo.v1"]
assert model_node.contract.enforced is True
assert len(model_node.constraints) == 1


# test primary key defined across model and column level constraints, expect error
class TestPrimaryKeysModelAndColumnLevelConstraints:
    """Versioned model: a PK set on the model and on a v2 column must fail parsing."""

    @pytest.fixture(scope="class")
    def models(self):
        return {"foo.sql": foo_sql, "schema.yml": schema_yml}

    def test_model_column_pk_error(self, project):
        root = project.project_root

        # Phase 1: the initial project runs and the model has one constraint.
        first_run = run_dbt(["run"])
        assert len(first_run) == 1
        unversioned_node = get_manifest(root).nodes["model.test.foo"]
        assert len(unversioned_node.constraints) == 1

        # Phase 2: replace foo.sql with foo_v1.sql and switch to the versioned schema.
        rm_file(root, "models", "foo.sql")
        write_file(foo_sql, root, "models", "foo_v1.sql")
        write_file(versioned_schema_yml, root, "models", "schema.yml")
        second_run = run_dbt(["run"])
        assert len(second_run) == 1

        v1_node = get_manifest(root).nodes["model.test.foo.v1"]
        assert v1_node.contract.enforced is True
        assert len(v1_node.constraints) == 1

        # Phase 3: add foo_v2.sql with the schema that conflicts on primary keys.
        write_file(foo_sql, root, "models", "foo_v2.sql")
        write_file(versioned_pk_model_column_schema_yml, root, "models", "schema.yml")

        with pytest.raises(ParsingError) as raised:
            run_dbt(["run"])

        message = str(raised.value)
        assert (
            "Primary key constraints defined at the model level and the columns level" in message
        )


# test primary key defined across multiple columns, expect error
class TestPrimaryKeysMultipleColumns:
    """Versioned model: primary keys on two v2 columns must fail parsing."""

    @pytest.fixture(scope="class")
    def models(self):
        return {"foo.sql": foo_sql, "schema.yml": schema_yml}

    def test_pk_multiple_columns(self, project):
        root = project.project_root

        # Phase 1: the initial project runs and the model has one constraint.
        first_run = run_dbt(["run"])
        assert len(first_run) == 1
        unversioned_node = get_manifest(root).nodes["model.test.foo"]
        assert len(unversioned_node.constraints) == 1

        # Phase 2: replace foo.sql with foo_v1.sql and switch to the versioned schema.
        rm_file(root, "models", "foo.sql")
        write_file(foo_sql, root, "models", "foo_v1.sql")
        write_file(versioned_schema_yml, root, "models", "schema.yml")
        second_run = run_dbt(["run"])
        assert len(second_run) == 1

        v1_node = get_manifest(root).nodes["model.test.foo.v1"]
        assert v1_node.contract.enforced is True
        assert len(v1_node.constraints) == 1

        # Phase 3: add foo_v2.sql with the schema that has two column-level PKs.
        write_file(foo_sql, root, "models", "foo_v2.sql")
        write_file(versioned_pk_mult_columns_schema_yml, root, "models", "schema.yml")

        with pytest.raises(ParsingError) as raised:
            run_dbt(["run"])

        message = str(raised.value)
        assert (
            "Found 2 columns (['id', 'user_name']) with primary key constraints defined"
            in message
        )

0 comments on commit ce10240

Please sign in to comment.