From c2a29569fd732a0e5800949a60a95271579e0593 Mon Sep 17 00:00:00 2001 From: Michelle Ark Date: Fri, 5 Jul 2024 14:45:21 -0400 Subject: [PATCH 01/12] parse + compile constraint.to on fk constraints --- core/dbt/clients/jinja_static.py | 59 +++++++++++++++++++++++++++- core/dbt/compilation.py | 34 ++++++++++++++++ core/dbt/contracts/graph/manifest.py | 27 +++++++++++++ core/dbt/parser/schemas.py | 2 +- dev-requirements.txt | 2 +- 5 files changed, 120 insertions(+), 4 deletions(-) diff --git a/core/dbt/clients/jinja_static.py b/core/dbt/clients/jinja_static.py index 8e0c34df2e6..16ebd2044fc 100644 --- a/core/dbt/clients/jinja_static.py +++ b/core/dbt/clients/jinja_static.py @@ -1,11 +1,13 @@ -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Tuple import jinja2 -from dbt.exceptions import MacroNamespaceNotStringError +from dbt.artifacts.resources import RefArgs +from dbt.exceptions import MacroNamespaceNotStringError, ParsingError from dbt_common.clients.jinja import get_environment from dbt_common.exceptions.macros import MacroNameNotStringError from dbt_common.tests import test_caching_enabled +from dbt_extractor import ExtractionError, py_extract_from_source # type: ignore _TESTING_MACRO_CACHE: Optional[Dict[str, Any]] = {} @@ -153,3 +155,56 @@ def statically_parse_adapter_dispatch(func_call, ctx, db_wrapper): possible_macro_calls.append(f"{package_name}.{func_name}") return possible_macro_calls + + +def statically_parse_ref(input: str) -> RefArgs: + """ + Returns a RefArgs object corresponding to an input jinja expression. + + input: str representing how input node is referenced in tested model sql + * examples: + - "ref('my_model_a')" + - "ref('my_model_a', version=3)" + - "ref('package', 'my_model_a', version=3)" + + If input is not a well-formed jinja expression, TODO is raised. + If input is not a valid ref expression, TODO is raised. + """ + try: + statically_parsed = py_extract_from_source(f"{{{{ {input} }}}}") + except ExtractionError: + # TODO: more precise error + raise ParsingError(f"Invalid jinja expression: {input}.") + + if not statically_parsed["refs"]: + # TODO: more precise error class + raise ParsingError("not a ref") + + ref = list(statically_parsed["refs"])[0] + return RefArgs(package=ref.get("package"), name=ref.get("name"), version=ref.get("version")) + + +def statically_parse_source(input: str) -> Tuple[str, str]: + """ + Returns a RefArgs object corresponding to an input jinja expression. + + input: str representing how input node is referenced in tested model sql + * examples: + - "source('my_source_schema', 'my_source_name')" + + If input is not a well-formed jinja expression, TODO is raised. + If input is not a valid source expression, TODO is raised. + """ + try: + statically_parsed = py_extract_from_source(f"{{{{ {input} }}}}") + except ExtractionError: + # TODO: more precise error + raise ParsingError(f"Invalid jinja expression: {input}.") + + if not statically_parsed["sources"]: + # TODO: more precise error class + raise ParsingError("not a ref") + + source = list(statically_parsed["sources"])[0] + source_name, source_table_name = source + return source_name, source_table_name diff --git a/core/dbt/compilation.py b/core/dbt/compilation.py index d03407b2a4c..d662d9c8e49 100644 --- a/core/dbt/compilation.py +++ b/core/dbt/compilation.py @@ -35,6 +35,7 @@ from dbt.graph import Graph from dbt.node_types import ModelLanguage, NodeType from dbt_common.clients.system import make_directory +from dbt_common.contracts.constraints import ConstraintType from dbt_common.events.contextvars import get_node_info from dbt_common.events.format import pluralize from dbt_common.events.functions import fire_event @@ -437,8 +438,41 @@ def _compile_code( relation_name = str(relation_cls.create_from(self.config, node)) node.relation_name = relation_name + # Compile 'ref' and 'source' expressions in foreign key constraints + if node.resource_type == NodeType.Model: + # column-level foreign key constraints + for column in node.columns.values(): + for column_constraint in column.constraints: + if ( + column_constraint.type == ConstraintType.foreign_key + and column_constraint.to + ): + column_constraint.to = ( + self._compile_relation_for_foreign_key_constraint_to( + manifest, node, column_constraint.to + ) + ) + + # model-level foreign key constraints + for model_constraint in node.constraints: + if model_constraint.type == ConstraintType.foreign_key and model_constraint.to: + model_constraint.to = self._compile_relation_for_foreign_key_constraint_to( + manifest, node, model_constraint.to + ) + return node + def _compile_relation_for_foreign_key_constraint_to( + self, manifest: Manifest, node: ManifestSQLNode, to_expression: str + ) -> str: + foreign_key_node = manifest.find_node_from_ref_or_source(to_expression) + if not foreign_key_node: + raise GraphDependencyNotFoundError(node, to_expression) + adapter = get_adapter(self.config) + relation_cls = adapter.Relation + relation_name = str(relation_cls.create_from(self.config, foreign_key_node)) + return relation_name + # This method doesn't actually "compile" any of the nodes. That is done by the # "compile_node" method. This creates a Linker and builds the networkx graph, # writes out the graph.gpickle file, and prints the stats, returning a Graph object. diff --git a/core/dbt/contracts/graph/manifest.py b/core/dbt/contracts/graph/manifest.py index 2bd183af759..df9b8c21ed3 100644 --- a/core/dbt/contracts/graph/manifest.py +++ b/core/dbt/contracts/graph/manifest.py @@ -35,6 +35,7 @@ from dbt.artifacts.resources import BaseResource, DeferRelation, NodeVersion from dbt.artifacts.resources.v1.config import NodeConfig from dbt.artifacts.schemas.manifest import ManifestMetadata, UniqueID, WritableManifest +from dbt.clients.jinja_static import statically_parse_ref, statically_parse_source from dbt.contracts.files import ( AnySourceFile, FileHash, @@ -1634,6 +1635,32 @@ def add_saved_query(self, source_file: SchemaSourceFile, saved_query: SavedQuery # end of methods formerly in ParseResult + def find_node_from_ref_or_source( + self, expression: str + ) -> Optional[Union[ModelNode, SourceDefinition]]: + valid_ref = True + valid_source = True + try: + ref = statically_parse_ref(expression) + # TODO: better error handling + except Exception: + valid_ref = False + try: + source_name, source_table_name = statically_parse_source(expression) + # TODO: better error handling + except Exception: + valid_source = False + + if not valid_ref and not valid_ref: + raise CompilationError(f"Invalid ref or source expression: {expression}") + + if valid_ref: + node = self.ref_lookup.find(ref.name, ref.package, ref.version, self) + elif valid_source: + node = self.source_lookup.find(f"{source_name}.{source_table_name}", None, self) + + return node + # Provide support for copy.deepcopy() - we just need to avoid the lock! # pickle and deepcopy use this. It returns a callable object used to # create the initial version of the object and a tuple of arguments diff --git a/core/dbt/parser/schemas.py b/core/dbt/parser/schemas.py index 284a01fc58e..7209de6e20d 100644 --- a/core/dbt/parser/schemas.py +++ b/core/dbt/parser/schemas.py @@ -906,7 +906,7 @@ def patch_node_properties(self, node, patch: "ParsedNodePatch"): self.patch_constraints(node, patch.constraints) node.build_contract_checksum() - def patch_constraints(self, node, constraints): + def patch_constraints(self, node, constraints: List[Dict[str, Any]]): contract_config = node.config.get("contract") if contract_config.enforced is True: self._validate_constraint_prerequisites(node) diff --git a/dev-requirements.txt b/dev-requirements.txt index 8541133ff9a..c42d44b89cf 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -1,6 +1,6 @@ git+https://github.com/dbt-labs/dbt-adapters.git@main git+https://github.com/dbt-labs/dbt-adapters.git@main#subdirectory=dbt-tests-adapter -git+https://github.com/dbt-labs/dbt-common.git@main +git+https://github.com/dbt-labs/dbt-common.git@foreign-ref-column-constraint git+https://github.com/dbt-labs/dbt-postgres.git@main # black must match what's in .pre-commit-config.yaml to be sure local env matches CI black==22.3.0 From c8d04154860546732d6ddc285289bbcd368d5b04 Mon Sep 17 00:00:00 2001 From: Michelle Ark Date: Tue, 16 Jul 2024 21:04:32 -0400 Subject: [PATCH 02/12] restore dev-requirements.txt --- dev-requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dev-requirements.txt b/dev-requirements.txt index c42d44b89cf..8541133ff9a 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -1,6 +1,6 @@ git+https://github.com/dbt-labs/dbt-adapters.git@main git+https://github.com/dbt-labs/dbt-adapters.git@main#subdirectory=dbt-tests-adapter -git+https://github.com/dbt-labs/dbt-common.git@foreign-ref-column-constraint +git+https://github.com/dbt-labs/dbt-common.git@main git+https://github.com/dbt-labs/dbt-postgres.git@main # black must match what's in .pre-commit-config.yaml to be sure local env matches CI black==22.3.0 From f247b55b536a92d2e54f4f5eee8c65a587db0617 Mon Sep 17 00:00:00 2001 From: Michelle Ark Date: Fri, 19 Jul 2024 15:11:03 -0400 Subject: [PATCH 03/12] clean up error handling --- core/dbt/clients/jinja_static.py | 16 ++++++---------- core/dbt/compilation.py | 8 +++++++- core/dbt/contracts/graph/manifest.py | 9 ++++----- core/dbt/exceptions.py | 12 ++++++++++++ 4 files changed, 29 insertions(+), 16 deletions(-) diff --git a/core/dbt/clients/jinja_static.py b/core/dbt/clients/jinja_static.py index 16ebd2044fc..a57c96642a1 100644 --- a/core/dbt/clients/jinja_static.py +++ b/core/dbt/clients/jinja_static.py @@ -173,12 +173,10 @@ def statically_parse_ref(input: str) -> RefArgs: try: statically_parsed = py_extract_from_source(f"{{{{ {input} }}}}") except ExtractionError: - # TODO: more precise error - raise ParsingError(f"Invalid jinja expression: {input}.") + raise ParsingError(f"Invalid jinja expression: {input}") - if not statically_parsed["refs"]: - # TODO: more precise error class - raise ParsingError("not a ref") + if not statically_parsed.get("refs"): + raise ParsingError(f"Invalid ref expression: {input}") ref = list(statically_parsed["refs"])[0] return RefArgs(package=ref.get("package"), name=ref.get("name"), version=ref.get("version")) @@ -198,12 +196,10 @@ def statically_parse_source(input: str) -> Tuple[str, str]: try: statically_parsed = py_extract_from_source(f"{{{{ {input} }}}}") except ExtractionError: - # TODO: more precise error - raise ParsingError(f"Invalid jinja expression: {input}.") + raise ParsingError(f"Invalid jinja expression: {input}") - if not statically_parsed["sources"]: - # TODO: more precise error class - raise ParsingError("not a ref") + if not statically_parsed.get("sources"): + raise ParsingError(f"Invalid source expression: {input}") source = list(statically_parsed["sources"])[0] source_name, source_table_name = source diff --git a/core/dbt/compilation.py b/core/dbt/compilation.py index d662d9c8e49..5c20f7fd641 100644 --- a/core/dbt/compilation.py +++ b/core/dbt/compilation.py @@ -29,7 +29,9 @@ from dbt.exceptions import ( DbtInternalError, DbtRuntimeError, + ForeignKeyConstraintToSyntaxError, GraphDependencyNotFoundError, + ParsingError, ) from dbt.flags import get_flags from dbt.graph import Graph @@ -465,7 +467,11 @@ def _compile_code( def _compile_relation_for_foreign_key_constraint_to( self, manifest: Manifest, node: ManifestSQLNode, to_expression: str ) -> str: - foreign_key_node = manifest.find_node_from_ref_or_source(to_expression) + try: + foreign_key_node = manifest.find_node_from_ref_or_source(to_expression) + except ParsingError: + raise ForeignKeyConstraintToSyntaxError(node, to_expression) + if not foreign_key_node: raise GraphDependencyNotFoundError(node, to_expression) adapter = get_adapter(self.config) diff --git a/core/dbt/contracts/graph/manifest.py b/core/dbt/contracts/graph/manifest.py index b6d445ce05d..587c14a8046 100644 --- a/core/dbt/contracts/graph/manifest.py +++ b/core/dbt/contracts/graph/manifest.py @@ -70,6 +70,7 @@ AmbiguousResourceNameRefError, CompilationError, DuplicateResourceNameError, + ParsingError, ) from dbt.flags import get_flags from dbt.mp_context import get_mp_context @@ -1643,17 +1644,15 @@ def find_node_from_ref_or_source( valid_source = True try: ref = statically_parse_ref(expression) - # TODO: better error handling - except Exception: + except ParsingError: valid_ref = False try: source_name, source_table_name = statically_parse_source(expression) - # TODO: better error handling - except Exception: + except ParsingError: valid_source = False if not valid_ref and not valid_ref: - raise CompilationError(f"Invalid ref or source expression: {expression}") + raise ParsingError(f"Invalid ref or source syntax: {expression}.") if valid_ref: node = self.ref_lookup.find(ref.name, ref.package, ref.version, self) diff --git a/core/dbt/exceptions.py b/core/dbt/exceptions.py index aec2b5e3826..27aa863fd17 100644 --- a/core/dbt/exceptions.py +++ b/core/dbt/exceptions.py @@ -136,6 +136,18 @@ def get_message(self) -> str: return msg +class ForeignKeyConstraintToSyntaxError(CompilationError): + def __init__(self, node, expression: str) -> None: + self.expression = expression + self.node = node + super().__init__(msg=self.get_message()) + + def get_message(self) -> str: + msg = f"'{self.node.unique_id}' defines a foreign key constraint 'to' expression which is not valid 'ref' or 'source' syntax: {self.expression}." + + return msg + + # client level exceptions From 6a8a522094cf33cbf5e2eb35321b537c157ac142 Mon Sep 17 00:00:00 2001 From: Michelle Ark Date: Fri, 19 Jul 2024 16:18:46 -0400 Subject: [PATCH 04/12] changelog entry --- .changes/unreleased/Features-20240719-161841.yaml | 6 ++++++ 1 file changed, 6 insertions(+) create mode 100644 .changes/unreleased/Features-20240719-161841.yaml diff --git a/.changes/unreleased/Features-20240719-161841.yaml b/.changes/unreleased/Features-20240719-161841.yaml new file mode 100644 index 00000000000..a84a9d45e9d --- /dev/null +++ b/.changes/unreleased/Features-20240719-161841.yaml @@ -0,0 +1,6 @@ +kind: Features +body: Support ref and source in foreign key constraint expressions +time: 2024-07-19T16:18:41.434278-04:00 +custom: + Author: michelleark + Issue: "8062" From ff096d7df975daff36bde810936fa68e8bb73346 Mon Sep 17 00:00:00 2001 From: Michelle Ark Date: Fri, 19 Jul 2024 16:20:37 -0400 Subject: [PATCH 05/12] cleanup docstrings --- core/dbt/clients/jinja_static.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/core/dbt/clients/jinja_static.py b/core/dbt/clients/jinja_static.py index a57c96642a1..b5ae181ab20 100644 --- a/core/dbt/clients/jinja_static.py +++ b/core/dbt/clients/jinja_static.py @@ -167,8 +167,7 @@ def statically_parse_ref(input: str) -> RefArgs: - "ref('my_model_a', version=3)" - "ref('package', 'my_model_a', version=3)" - If input is not a well-formed jinja expression, TODO is raised. - If input is not a valid ref expression, TODO is raised. + If input is not a well-formed jinja ref expression, a ParsingError is raised. """ try: statically_parsed = py_extract_from_source(f"{{{{ {input} }}}}") @@ -190,8 +189,7 @@ def statically_parse_source(input: str) -> Tuple[str, str]: * examples: - "source('my_source_schema', 'my_source_name')" - If input is not a well-formed jinja expression, TODO is raised. - If input is not a valid source expression, TODO is raised. + If input is not a well-formed jinja source expression, ParsingError is raised. """ try: statically_parsed = py_extract_from_source(f"{{{{ {input} }}}}") From 8591dd44ed5dd1911642329ac2173005253d6c57 Mon Sep 17 00:00:00 2001 From: Michelle Ark Date: Fri, 19 Jul 2024 16:31:23 -0400 Subject: [PATCH 06/12] migrate test_jinja_static from unittest --- tests/unit/clients/test_jinja_static.py | 68 ++++++++++++++----------- 1 file changed, 39 insertions(+), 29 deletions(-) diff --git a/tests/unit/clients/test_jinja_static.py b/tests/unit/clients/test_jinja_static.py index d575cfb76e8..e09097c2026 100644 --- a/tests/unit/clients/test_jinja_static.py +++ b/tests/unit/clients/test_jinja_static.py @@ -1,44 +1,54 @@ -import unittest +import pytest from dbt.clients.jinja_static import statically_extract_macro_calls from dbt.context.base import generate_base_context -class MacroCalls(unittest.TestCase): - def setUp(self): - self.macro_strings = [ +@pytest.mark.parametrize( + "macro_string,expected_possible_macro_calls", + [ + ( "{% macro parent_macro() %} {% do return(nested_macro()) %} {% endmacro %}", - "{% macro lr_macro() %} {{ return(load_result('relations').table) }} {% endmacro %}", - "{% macro get_snapshot_unique_id() -%} {{ return(adapter.dispatch('get_snapshot_unique_id')()) }} {%- endmacro %}", - "{% macro get_columns_in_query(select_sql) -%} {{ return(adapter.dispatch('get_columns_in_query')(select_sql)) }} {% endmacro %}", - """{% macro test_mutually_exclusive_ranges(model) %} - with base as ( - select {{ get_snapshot_unique_id() }} as dbt_unique_id, - * - from {{ model }} ) - {% endmacro %}""", - "{% macro test_my_test(model) %} select {{ current_timestamp_backcompat() }} {% endmacro %}", - "{% macro some_test(model) -%} {{ return(adapter.dispatch('test_some_kind4', 'foo_utils4')) }} {%- endmacro %}", - "{% macro some_test(model) -%} {{ return(adapter.dispatch('test_some_kind5', macro_namespace = 'foo_utils5')) }} {%- endmacro %}", - ] - - self.possible_macro_calls = [ ["nested_macro"], + ), + ( + "{% macro lr_macro() %} {{ return(load_result('relations').table) }} {% endmacro %}", ["load_result"], + ), + ( + "{% macro get_snapshot_unique_id() -%} {{ return(adapter.dispatch('get_snapshot_unique_id')()) }} {%- endmacro %}", ["get_snapshot_unique_id"], + ), + ( + "{% macro get_columns_in_query(select_sql) -%} {{ return(adapter.dispatch('get_columns_in_query')(select_sql)) }} {% endmacro %}", ["get_columns_in_query"], + ), + ( + """{% macro test_mutually_exclusive_ranges(model) %} + with base as ( + select {{ get_snapshot_unique_id() }} as dbt_unique_id, + * + from {{ model }} ) + {% endmacro %}""", ["get_snapshot_unique_id"], + ), + ( + "{% macro test_my_test(model) %} select {{ current_timestamp_backcompat() }} {% endmacro %}", ["current_timestamp_backcompat"], + ), + ( + "{% macro some_test(model) -%} {{ return(adapter.dispatch('test_some_kind4', 'foo_utils4')) }} {%- endmacro %}", ["test_some_kind4", "foo_utils4.test_some_kind4"], + ), + ( + "{% macro some_test(model) -%} {{ return(adapter.dispatch('test_some_kind5', macro_namespace = 'foo_utils5')) }} {%- endmacro %}", ["test_some_kind5", "foo_utils5.test_some_kind5"], - ] - - def test_macro_calls(self): - cli_vars = {"local_utils_dispatch_list": ["foo_utils4"]} - ctx = generate_base_context(cli_vars) + ), + ], +) +def test_extract_macro_calls(self, macro_string, expected_possible_macro_calls): + cli_vars = {"local_utils_dispatch_list": ["foo_utils4"]} + ctx = generate_base_context(cli_vars) - index = 0 - for macro_string in self.macro_strings: - possible_macro_calls = statically_extract_macro_calls(macro_string, ctx) - self.assertEqual(self.possible_macro_calls[index], possible_macro_calls) - index += 1 + possible_macro_calls = statically_extract_macro_calls(macro_string, ctx) + assert possible_macro_calls == expected_possible_macro_calls From 30935003f3d70bd6c23cd4b913075c37d553abc8 Mon Sep 17 00:00:00 2001 From: Michelle Ark Date: Fri, 19 Jul 2024 16:42:15 -0400 Subject: [PATCH 07/12] add unit tests for statically_parse_ref and statically_parse_source --- tests/unit/clients/test_jinja_static.py | 41 +++++++++++++++++++++++-- 1 file changed, 39 insertions(+), 2 deletions(-) diff --git a/tests/unit/clients/test_jinja_static.py b/tests/unit/clients/test_jinja_static.py index e09097c2026..52dc1f671f0 100644 --- a/tests/unit/clients/test_jinja_static.py +++ b/tests/unit/clients/test_jinja_static.py @@ -1,7 +1,13 @@ import pytest -from dbt.clients.jinja_static import statically_extract_macro_calls +from dbt.artifacts.resources import RefArgs +from dbt.clients.jinja_static import ( + statically_extract_macro_calls, + statically_parse_ref, + statically_parse_source, +) from dbt.context.base import generate_base_context +from dbt.exceptions import ParsingError @pytest.mark.parametrize( @@ -46,9 +52,40 @@ ), ], ) -def test_extract_macro_calls(self, macro_string, expected_possible_macro_calls): +def test_extract_macro_calls(macro_string, expected_possible_macro_calls): cli_vars = {"local_utils_dispatch_list": ["foo_utils4"]} ctx = generate_base_context(cli_vars) possible_macro_calls = statically_extract_macro_calls(macro_string, ctx) assert possible_macro_calls == expected_possible_macro_calls + + +class TestStaticallyParseRef: + @pytest.mark.parametrize("invalid_expression", ["invalid", "source('schema', 'table')"]) + def test_invalid_expression(self, invalid_expression): + with pytest.raises(ParsingError): + statically_parse_ref(invalid_expression) + + @pytest.mark.parametrize( + "ref_expression,expected_ref_args", + [ + ("ref('model')", RefArgs(name="model")), + ("ref('package','model')", RefArgs(name="model", package="package")), + ("ref('model',v=3)", RefArgs(name="model", version=3)), + ("ref('package','model',v=3)", RefArgs(name="model", package="package", version=3)), + ], + ) + def test_valid_ref_expression(self, ref_expression, expected_ref_args): + ref_args = statically_parse_ref(ref_expression) + assert ref_args == expected_ref_args + + +class TestStaticallyParseSource: + @pytest.mark.parametrize("invalid_expression", ["invalid", "ref('package', 'model')"]) + def test_invalid_expression(self, invalid_expression): + with pytest.raises(ParsingError): + statically_parse_source(invalid_expression) + + def test_valid_ref_expression(self): + parsed_source = statically_parse_source("source('schema', 'table')") + assert parsed_source == ("schema", "table") From e233fc56bb86a1438c5fc9c2f97e89367c1245c0 Mon Sep 17 00:00:00 2001 From: Michelle Ark Date: Fri, 19 Jul 2024 17:47:19 -0400 Subject: [PATCH 08/12] add unit tests manifest.find_node_from_ref_or_source --- core/dbt/contracts/graph/manifest.py | 3 +- tests/unit/contracts/graph/test_manifest.py | 52 ++++++++++++++++++++- 2 files changed, 53 insertions(+), 2 deletions(-) diff --git a/core/dbt/contracts/graph/manifest.py b/core/dbt/contracts/graph/manifest.py index 587c14a8046..eb7c5c1b887 100644 --- a/core/dbt/contracts/graph/manifest.py +++ b/core/dbt/contracts/graph/manifest.py @@ -1651,9 +1651,10 @@ def find_node_from_ref_or_source( except ParsingError: valid_source = False - if not valid_ref and not valid_ref: + if not valid_ref and not valid_source: raise ParsingError(f"Invalid ref or source syntax: {expression}.") + node = None if valid_ref: node = self.ref_lookup.find(ref.name, ref.package, ref.version, self) elif valid_source: diff --git a/tests/unit/contracts/graph/test_manifest.py b/tests/unit/contracts/graph/test_manifest.py index 35e96308da7..dc81fa4b7dc 100644 --- a/tests/unit/contracts/graph/test_manifest.py +++ b/tests/unit/contracts/graph/test_manifest.py @@ -37,7 +37,7 @@ SeedNode, SourceDefinition, ) -from dbt.exceptions import AmbiguousResourceNameRefError +from dbt.exceptions import AmbiguousResourceNameRefError, ParsingError from dbt.flags import set_from_args from dbt.node_types import NodeType from dbt_common.events.functions import reset_metadata_vars @@ -1962,3 +1962,53 @@ def test_resolve_doc(docs, package, expected): expected_package, expected_name = expected assert result.name == expected_name assert result.package_name == expected_package + + +class TestManifestFindNodeFromRefOrSource: + @pytest.fixture + def mock_node(self): + return MockNode("my_package", "my_model") + + @pytest.fixture + def mock_disabled_node(self): + return MockNode("my_package", "disabled_node", config={"enabled": False}) + + @pytest.fixture + def mock_source(self): + return MockSource("root", "my_source", "source_table") + + @pytest.fixture + def mock_disabled_source(self): + return MockSource("root", "my_source", "disabled_source_table", config={"enabled": False}) + + @pytest.fixture + def mock_manifest(self, mock_node, mock_source, mock_disabled_node, mock_disabled_source): + return make_manifest( + nodes=[mock_node, mock_disabled_node], sources=[mock_source, mock_disabled_source] + ) + + @pytest.mark.parametrize( + "expression,expected_node", + [ + ("ref('my_package', 'my_model')", "mock_node"), + ("ref('my_package', 'doesnt_exist')", None), + ("ref('my_package', 'disabled_node')", "mock_disabled_node"), + ("source('my_source', 'source_table')", "mock_source"), + ("source('my_source', 'doesnt_exist')", None), + ("source('my_source', 'disabled_source_table')", "mock_disabled_source"), + ], + ) + def test_find_node_from_ref_or_source(self, expression, expected_node, mock_manifest, request): + node = mock_manifest.find_node_from_ref_or_source(expression) + + if expected_node is None: + assert node is None + else: + assert node == request.getfixturevalue(expected_node) + + @pytest.mark.parametrize("invalid_expression", ["invalid", "ref(')"]) + def test_find_node_from_ref_or_source_invalid_expression( + self, invalid_expression, mock_manifest + ): + with pytest.raises(ParsingError): + mock_manifest.find_node_from_ref_or_source(invalid_expression) From 892ed9e0107125fec37146689138be94bbeec531 Mon Sep 17 00:00:00 2001 From: Michelle Ark Date: Mon, 22 Jul 2024 11:53:29 -0400 Subject: [PATCH 09/12] add functional tests for foreign key constraint parsing and compilation --- tests/functional/constraints/fixtures.py | 115 ++++++++++ .../test_foreign_key_constraints.py | 215 ++++++++++++++++++ 2 files changed, 330 insertions(+) create mode 100644 tests/functional/constraints/fixtures.py create mode 100644 tests/functional/constraints/test_foreign_key_constraints.py diff --git a/tests/functional/constraints/fixtures.py b/tests/functional/constraints/fixtures.py new file mode 100644 index 00000000000..de60963bfec --- /dev/null +++ b/tests/functional/constraints/fixtures.py @@ -0,0 +1,115 @@ +model_foreign_key_model_schema_yml = """ +models: + - name: my_model + constraints: + - type: foreign_key + columns: [id] + to: ref('my_model_to') + to_columns: [id] + columns: + - name: id + data_type: integer +""" + + +model_foreign_key_source_schema_yml = """ +sources: + - name: test_source + tables: + - name: test_table + +models: + - name: my_model + constraints: + - type: foreign_key + columns: [id] + to: source('test_source', 'test_table') + to_columns: [id] + columns: + - name: id + data_type: integer +""" + + +model_foreign_key_model_node_not_found_schema_yml = """ +models: + - name: my_model + constraints: + - type: foreign_key + columns: [id] + to: ref('doesnt_exist') + to_columns: [id] + columns: + - name: id + data_type: integer +""" + + +model_foreign_key_model_invalid_syntax_schema_yml = """ +models: + - name: my_model + constraints: + - type: foreign_key + columns: [id] + to: invalid + to_columns: [id] + columns: + - name: id + data_type: integer +""" + + +model_foreign_key_model_column_schema_yml = """ +models: + - name: my_model + columns: + - name: id + data_type: integer + constraints: + - type: foreign_key + to: ref('my_model_to') + to_columns: [id] +""" + + +model_foreign_key_column_invalid_syntax_schema_yml = """ +models: + - name: my_model + columns: + - name: id + data_type: integer + constraints: + - type: foreign_key + to: invalid + to_columns: [id] +""" + + +model_foreign_key_column_node_not_found_schema_yml = """ +models: + - name: my_model + columns: + - name: id + data_type: integer + constraints: + - type: foreign_key + to: ref('doesnt_exist') + to_columns: [id] +""" + +model_column_level_foreign_key_source_schema_yml = """ +sources: + - name: test_source + tables: + - name: test_table + +models: + - name: my_model + columns: + - name: id + data_type: integer + constraints: + - type: foreign_key + to: source('test_source', 'test_table') + to_columns: [id] +""" diff --git a/tests/functional/constraints/test_foreign_key_constraints.py b/tests/functional/constraints/test_foreign_key_constraints.py new file mode 100644 index 00000000000..a7c173ac659 --- /dev/null +++ b/tests/functional/constraints/test_foreign_key_constraints.py @@ -0,0 +1,215 @@ +import pytest + +from dbt.exceptions import DbtRuntimeError +from dbt.tests.util import get_artifact, run_dbt +from dbt_common.contracts.constraints import ( + ColumnLevelConstraint, + ConstraintType, + ModelLevelConstraint, +) +from tests.functional.constraints.fixtures import ( + model_column_level_foreign_key_source_schema_yml, + model_foreign_key_column_invalid_syntax_schema_yml, + model_foreign_key_column_node_not_found_schema_yml, + model_foreign_key_model_column_schema_yml, + model_foreign_key_model_invalid_syntax_schema_yml, + model_foreign_key_model_node_not_found_schema_yml, + model_foreign_key_model_schema_yml, + model_foreign_key_source_schema_yml, +) + + +class TestModelLevelForeignKeyConstraintToRef: + @pytest.fixture(scope="class") + def models(self): + return { + "constraints_schema.yml": model_foreign_key_model_schema_yml, + "my_model.sql": "select 1 as id", + "my_model_to.sql": "select 1 as id", + } + + def test_model_level_fk_to(self, project, unique_schema): + manifest = run_dbt(["parse"]) + assert len(manifest.nodes["model.test.my_model"].constraints) == 1 + + parsed_constraint = manifest.nodes["model.test.my_model"].constraints[0] + assert parsed_constraint == ModelLevelConstraint( + type=ConstraintType.foreign_key, + columns=["id"], + to="ref('my_model_to')", + to_columns=["id"], + ) + + # Assert compilation renders to from 'ref' to relation identifer + run_dbt(["compile"]) + manifest = get_artifact(project.project_root, "target", "manifest.json") + assert len(manifest["nodes"]["model.test.my_model"]["constraints"]) == 1 + + compiled_constraint = manifest["nodes"]["model.test.my_model"]["constraints"][0] + assert compiled_constraint["to"] == f'"dbt"."{unique_schema}"."my_model_to"' + # Other constraint fields should remain as parsed + assert compiled_constraint["to_columns"] == parsed_constraint.to_columns + assert compiled_constraint["columns"] == parsed_constraint.columns + assert compiled_constraint["type"] == parsed_constraint.type + + +class TestModelLevelForeignKeyConstraintToSource: + @pytest.fixture(scope="class") + def models(self): + return { + "constraints_schema.yml": model_foreign_key_source_schema_yml, + "my_model.sql": "select 1 as id", + "my_model_to.sql": "select 1 as id", + } + + def test_model_level_fk_to(self, project, unique_schema): + manifest = run_dbt(["parse"]) + assert len(manifest.nodes["model.test.my_model"].constraints) == 1 + + parsed_constraint = manifest.nodes["model.test.my_model"].constraints[0] + assert parsed_constraint == ModelLevelConstraint( + type=ConstraintType.foreign_key, + columns=["id"], + to="source('test_source', 'test_table')", + to_columns=["id"], + ) + + # Assert compilation renders to from 'ref' to relation identifer + run_dbt(["compile"]) + manifest = get_artifact(project.project_root, "target", "manifest.json") + assert len(manifest["nodes"]["model.test.my_model"]["constraints"]) == 1 + + compiled_constraint = manifest["nodes"]["model.test.my_model"]["constraints"][0] + assert compiled_constraint["to"] == '"dbt"."test_source"."test_table"' + # Other constraint fields should remain as parsed + assert compiled_constraint["to_columns"] == parsed_constraint.to_columns + assert compiled_constraint["columns"] == parsed_constraint.columns + assert compiled_constraint["type"] == parsed_constraint.type + + +class TestModelLevelForeignKeyConstraintRefNotFoundError: + @pytest.fixture(scope="class") + def models(self): + return { + "constraints_schema.yml": model_foreign_key_model_node_not_found_schema_yml, + "my_model.sql": "select 1 as id", + "my_model_to.sql": "select 1 as id", + } + + def test_model_level_fk_to_doesnt_exist(self, project): + with pytest.raises(DbtRuntimeError, match="not in the graph"): + run_dbt(["compile"]) + + +class TestModelLevelForeignKeyConstraintRefSyntaxError: + @pytest.fixture(scope="class") + def models(self): + return { + "constraints_schema.yml": model_foreign_key_model_invalid_syntax_schema_yml, + "my_model.sql": "select 1 as id", + "my_model_to.sql": "select 1 as id", + } + + def test_model_level_fk_to(self, project): + with pytest.raises( + DbtRuntimeError, + match="'model.test.my_model' defines a foreign key constraint 'to' expression which is not valid 'ref' or 'source' syntax", + ): + run_dbt(["compile"]) + + +class TestColumnLevelForeignKeyConstraintToRef: + @pytest.fixture(scope="class") + def models(self): + return { + "constraints_schema.yml": model_foreign_key_model_column_schema_yml, + "my_model.sql": "select 1 as id", + "my_model_to.sql": "select 1 as id", + } + + def test_column_level_fk_to(self, project, unique_schema): + manifest = run_dbt(["parse"]) + assert len(manifest.nodes["model.test.my_model"].columns["id"].constraints) == 1 + + parsed_constraint = manifest.nodes["model.test.my_model"].columns["id"].constraints[0] + assert parsed_constraint == ColumnLevelConstraint( + type=ConstraintType.foreign_key, to="ref('my_model_to')", to_columns=["id"] + ) + + # Assert compilation renders to from 'ref' to relation identifer + run_dbt(["compile"]) + manifest = get_artifact(project.project_root, "target", "manifest.json") + assert len(manifest["nodes"]["model.test.my_model"]["columns"]["id"]["constraints"]) == 1 + + compiled_constraint = manifest["nodes"]["model.test.my_model"]["columns"]["id"][ + "constraints" + ][0] + assert compiled_constraint["to"] == f'"dbt"."{unique_schema}"."my_model_to"' + # Other constraint fields should remain as parsed + assert compiled_constraint["to_columns"] == parsed_constraint.to_columns + assert compiled_constraint["type"] == parsed_constraint.type + + +class TestColumnLevelForeignKeyConstraintToSource: + @pytest.fixture(scope="class") + def models(self): + return { + "constraints_schema.yml": model_column_level_foreign_key_source_schema_yml, + "my_model.sql": "select 1 as id", + "my_model_to.sql": "select 1 as id", + } + + def test_model_level_fk_to(self, project, unique_schema): + manifest = run_dbt(["parse"]) + assert len(manifest.nodes["model.test.my_model"].columns["id"].constraints) == 1 + + parsed_constraint = manifest.nodes["model.test.my_model"].columns["id"].constraints[0] + assert parsed_constraint == ColumnLevelConstraint( + type=ConstraintType.foreign_key, + to="source('test_source', 'test_table')", + to_columns=["id"], + ) + + # Assert compilation renders to from 'ref' to relation identifer + run_dbt(["compile"]) + manifest = get_artifact(project.project_root, "target", "manifest.json") + assert len(manifest["nodes"]["model.test.my_model"]["columns"]["id"]["constraints"]) == 1 + + compiled_constraint = manifest["nodes"]["model.test.my_model"]["columns"]["id"][ + "constraints" + ][0] + assert compiled_constraint["to"] == '"dbt"."test_source"."test_table"' + # # Other constraint fields should remain as parsed + assert compiled_constraint["to_columns"] == parsed_constraint.to_columns + assert compiled_constraint["type"] == parsed_constraint.type + + +class TestColumnLevelForeignKeyConstraintRefNotFoundError: + @pytest.fixture(scope="class") + def models(self): + return { + "constraints_schema.yml": model_foreign_key_column_node_not_found_schema_yml, + "my_model.sql": "select 1 as id", + "my_model_to.sql": "select 1 as id", + } + + def test_model_level_fk_to_doesnt_exist(self, project): + with pytest.raises(DbtRuntimeError, match="not in the graph"): + run_dbt(["compile"]) + + +class TestColumnLevelForeignKeyConstraintRefSyntaxError: + @pytest.fixture(scope="class") + def models(self): + return { + "constraints_schema.yml": model_foreign_key_column_invalid_syntax_schema_yml, + "my_model.sql": "select 1 as id", + "my_model_to.sql": "select 1 as id", + } + + def test_model_level_fk_to(self, project): + with pytest.raises( + DbtRuntimeError, + match="'model.test.my_model' defines a foreign key constraint 'to' expression which is not valid 'ref' or 'source' syntax", + ): + run_dbt(["compile"]) From 626421d85ccfdf2d4dd6165c69f66a14fc5d9a05 Mon Sep 17 00:00:00 2001 From: Michelle Ark Date: Mon, 22 Jul 2024 18:10:29 -0400 Subject: [PATCH 10/12] include ref/sources from fk constraints in node depends_on --- core/dbt/clients/jinja_static.py | 26 +++++++- core/dbt/compilation.py | 26 ++------ core/dbt/contracts/graph/manifest.py | 28 +++----- core/dbt/contracts/graph/nodes.py | 18 +++++- core/dbt/parser/schemas.py | 22 +++++++ dev-requirements.txt | 2 +- .../test_foreign_key_constraints.py | 64 +++++++++++++------ tests/unit/clients/test_jinja_static.py | 23 ++++++- tests/unit/graph/test_nodes.py | 42 ++++++++++++ 9 files changed, 187 insertions(+), 64 deletions(-) diff --git a/core/dbt/clients/jinja_static.py b/core/dbt/clients/jinja_static.py index b5ae181ab20..6082f03f80c 100644 --- a/core/dbt/clients/jinja_static.py +++ b/core/dbt/clients/jinja_static.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, List, Optional, Union import jinja2 @@ -181,7 +181,7 @@ def statically_parse_ref(input: str) -> RefArgs: return RefArgs(package=ref.get("package"), name=ref.get("name"), version=ref.get("version")) -def statically_parse_source(input: str) -> Tuple[str, str]: +def statically_parse_source(input: str) -> List[str]: """ Returns a RefArgs object corresponding to an input jinja expression. @@ -201,4 +201,24 @@ def statically_parse_source(input: str) -> Tuple[str, str]: source = list(statically_parsed["sources"])[0] source_name, source_table_name = source - return source_name, source_table_name + return [source_name, source_table_name] + + +def statically_parse_ref_or_source(expression: str) -> Union[RefArgs, List[str]]: + ref_or_source: Union[RefArgs, List[str]] + valid_ref = True + valid_source = True + + try: + ref_or_source = statically_parse_ref(expression) + except ParsingError: + valid_ref = False + try: + ref_or_source = statically_parse_source(expression) + except ParsingError: + valid_source = False + + if not valid_ref and not valid_source: + raise ParsingError(f"Invalid ref or source syntax: {expression}.") + + return ref_or_source diff --git a/core/dbt/compilation.py b/core/dbt/compilation.py index 5c20f7fd641..47d7ffbdb51 100644 --- a/core/dbt/compilation.py +++ b/core/dbt/compilation.py @@ -442,24 +442,10 @@ def _compile_code( # Compile 'ref' and 'source' expressions in foreign key constraints if node.resource_type == NodeType.Model: - # column-level foreign key constraints - for column in node.columns.values(): - for column_constraint in column.constraints: - if ( - column_constraint.type == ConstraintType.foreign_key - and column_constraint.to - ): - column_constraint.to = ( - self._compile_relation_for_foreign_key_constraint_to( - manifest, node, column_constraint.to - ) - ) - - # model-level foreign key constraints - for model_constraint in node.constraints: - if model_constraint.type == ConstraintType.foreign_key and model_constraint.to: - model_constraint.to = self._compile_relation_for_foreign_key_constraint_to( - manifest, node, model_constraint.to + for constraint in node.all_constraints: + if constraint.type == ConstraintType.foreign_key and constraint.to: + constraint.to = self._compile_relation_for_foreign_key_constraint_to( + manifest, node, constraint.to ) return node @@ -474,9 +460,9 @@ def _compile_relation_for_foreign_key_constraint_to( if not foreign_key_node: raise GraphDependencyNotFoundError(node, to_expression) + adapter = get_adapter(self.config) - relation_cls = adapter.Relation - relation_name = str(relation_cls.create_from(self.config, foreign_key_node)) + relation_name = str(adapter.Relation.create_from(self.config, foreign_key_node)) return relation_name # This method doesn't actually "compile" any of the nodes. That is done by the diff --git a/core/dbt/contracts/graph/manifest.py b/core/dbt/contracts/graph/manifest.py index eb7c5c1b887..21c5571b74b 100644 --- a/core/dbt/contracts/graph/manifest.py +++ b/core/dbt/contracts/graph/manifest.py @@ -32,10 +32,10 @@ from dbt.adapters.factory import get_adapter_package_names # to preserve import paths -from dbt.artifacts.resources import BaseResource, DeferRelation, NodeVersion +from dbt.artifacts.resources import BaseResource, DeferRelation, NodeVersion, RefArgs from dbt.artifacts.resources.v1.config import NodeConfig from dbt.artifacts.schemas.manifest import ManifestMetadata, UniqueID, WritableManifest -from dbt.clients.jinja_static import statically_parse_ref, statically_parse_source +from dbt.clients.jinja_static import statically_parse_ref_or_source from dbt.contracts.files import ( AnySourceFile, FileHash, @@ -70,7 +70,6 @@ AmbiguousResourceNameRefError, CompilationError, DuplicateResourceNameError, - ParsingError, ) from dbt.flags import get_flags from dbt.mp_context import get_mp_context @@ -1640,24 +1639,15 @@ def add_saved_query(self, source_file: SchemaSourceFile, saved_query: SavedQuery def find_node_from_ref_or_source( self, expression: str ) -> Optional[Union[ModelNode, SourceDefinition]]: - valid_ref = True - valid_source = True - try: - ref = statically_parse_ref(expression) - except ParsingError: - valid_ref = False - try: - source_name, source_table_name = statically_parse_source(expression) - except ParsingError: - valid_source = False - - if not valid_ref and not valid_source: - raise ParsingError(f"Invalid ref or source syntax: {expression}.") + ref_or_source = statically_parse_ref_or_source(expression) node = None - if valid_ref: - node = self.ref_lookup.find(ref.name, ref.package, ref.version, self) - elif valid_source: + if isinstance(ref_or_source, RefArgs): + node = self.ref_lookup.find( + ref_or_source.name, ref_or_source.package, ref_or_source.version, self + ) + else: + source_name, source_table_name = ref_or_source[0], ref_or_source[1] node = self.source_lookup.find(f"{source_name}.{source_table_name}", None, self) return node diff --git a/core/dbt/contracts/graph/nodes.py b/core/dbt/contracts/graph/nodes.py index da42fb7d766..42d19e2c8dd 100644 --- a/core/dbt/contracts/graph/nodes.py +++ b/core/dbt/contracts/graph/nodes.py @@ -85,7 +85,11 @@ NodeType, ) from dbt_common.clients.system import write_file -from dbt_common.contracts.constraints import ConstraintType +from dbt_common.contracts.constraints import ( + ColumnLevelConstraint, + ConstraintType, + ModelLevelConstraint, +) from dbt_common.events.contextvars import set_log_contextvars from dbt_common.events.functions import warn_or_error @@ -489,6 +493,18 @@ def search_name(self): def materialization_enforces_constraints(self) -> bool: return self.config.materialized in ["table", "incremental"] + @property + def all_constraints(self) -> List[Union[ModelLevelConstraint, ColumnLevelConstraint]]: + constraints: List[Union[ModelLevelConstraint, ColumnLevelConstraint]] = [] + for model_level_constraint in self.constraints: + constraints.append(model_level_constraint) + + for column in self.columns.values(): + for column_level_constraint in column.constraints: + constraints.append(column_level_constraint) + + return constraints + def infer_primary_key(self, data_tests: List["GenericTestNode"]) -> List[str]: """ Infers the columns that can be used as primary key of a model in the following order: diff --git a/core/dbt/parser/schemas.py b/core/dbt/parser/schemas.py index 3b50ae0f98b..5e269fd385c 100644 --- a/core/dbt/parser/schemas.py +++ b/core/dbt/parser/schemas.py @@ -5,6 +5,8 @@ from typing import Any, Callable, Dict, Generic, Iterable, List, Optional, Type, TypeVar from dbt import deprecations +from dbt.artifacts.resources import RefArgs +from dbt.clients.jinja_static import statically_parse_ref_or_source from dbt.clients.yaml_helper import load_yaml_text from dbt.config import RuntimeConfig from dbt.context.configured import SchemaYamlVars, generate_schema_yml_context @@ -930,6 +932,26 @@ def patch_constraints(self, node, constraints: List[Dict[str, Any]]) -> None: self._validate_pk_constraints(node, constraints) node.constraints = [ModelLevelConstraint.from_dict(c) for c in constraints] + self._process_constraints_refs_and_sources(node) + + def _process_constraints_refs_and_sources(self, model_node: ModelNode) -> None: + """ + Populate model_node.refs and model_node.sources based on foreign-key constraint references, + whether defined at the model-level or column-level. + """ + for constraint in model_node.all_constraints: + if constraint.type == ConstraintType.foreign_key and constraint.to: + try: + ref_or_source = statically_parse_ref_or_source(constraint.to) + except ParsingError: + raise ParsingError( + f"Invalid 'ref' or 'source' syntax on foreign key constraint 'to' on model {model_node.name}: {constraint.to}." + ) + + if isinstance(ref_or_source, RefArgs): + model_node.refs.append(ref_or_source) + else: + model_node.sources.append(ref_or_source) def _validate_pk_constraints( self, model_node: ModelNode, constraints: List[Dict[str, Any]] diff --git a/dev-requirements.txt b/dev-requirements.txt index 20605e632b8..c5ddd0217e6 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -1,4 +1,4 @@ -git+https://github.com/dbt-labs/dbt-adapters.git@main +git+https://github.com/dbt-labs/dbt-adapters.git@render-foreign-constraint-ref git+https://github.com/dbt-labs/dbt-adapters.git@main#subdirectory=dbt-tests-adapter git+https://github.com/dbt-labs/dbt-common.git@main git+https://github.com/dbt-labs/dbt-postgres.git@main diff --git a/tests/functional/constraints/test_foreign_key_constraints.py b/tests/functional/constraints/test_foreign_key_constraints.py index a7c173ac659..2c02cfe7ad7 100644 --- a/tests/functional/constraints/test_foreign_key_constraints.py +++ b/tests/functional/constraints/test_foreign_key_constraints.py @@ -1,6 +1,7 @@ import pytest -from dbt.exceptions import DbtRuntimeError +from dbt.artifacts.resources import RefArgs +from dbt.exceptions import CompilationError, ParsingError from dbt.tests.util import get_artifact, run_dbt from dbt_common.contracts.constraints import ( ColumnLevelConstraint, @@ -30,15 +31,20 @@ def models(self): def test_model_level_fk_to(self, project, unique_schema): manifest = run_dbt(["parse"]) - assert len(manifest.nodes["model.test.my_model"].constraints) == 1 + node_with_fk_constraint = manifest.nodes["model.test.my_model"] + assert len(node_with_fk_constraint.constraints) == 1 - parsed_constraint = manifest.nodes["model.test.my_model"].constraints[0] + parsed_constraint = node_with_fk_constraint.constraints[0] assert parsed_constraint == ModelLevelConstraint( type=ConstraintType.foreign_key, columns=["id"], to="ref('my_model_to')", to_columns=["id"], ) + # Assert column-level constraint source included in node.depends_on + assert node_with_fk_constraint.refs == [RefArgs("my_model_to")] + assert node_with_fk_constraint.depends_on.nodes == ["model.test.my_model_to"] + assert node_with_fk_constraint.sources == [] # Assert compilation renders to from 'ref' to relation identifer run_dbt(["compile"]) @@ -64,15 +70,20 @@ def models(self): def test_model_level_fk_to(self, project, unique_schema): manifest = run_dbt(["parse"]) - assert len(manifest.nodes["model.test.my_model"].constraints) == 1 + node_with_fk_constraint = manifest.nodes["model.test.my_model"] + assert len(node_with_fk_constraint.constraints) == 1 - parsed_constraint = manifest.nodes["model.test.my_model"].constraints[0] + parsed_constraint = node_with_fk_constraint.constraints[0] assert parsed_constraint == ModelLevelConstraint( type=ConstraintType.foreign_key, columns=["id"], to="source('test_source', 'test_table')", to_columns=["id"], ) + # Assert column-level constraint source included in node.depends_on + assert node_with_fk_constraint.refs == [] + assert node_with_fk_constraint.depends_on.nodes == ["source.test.test_source.test_table"] + assert node_with_fk_constraint.sources == [["test_source", "test_table"]] # Assert compilation renders to from 'ref' to relation identifer run_dbt(["compile"]) @@ -97,8 +108,10 @@ def models(self): } def test_model_level_fk_to_doesnt_exist(self, project): - with pytest.raises(DbtRuntimeError, match="not in the graph"): - run_dbt(["compile"]) + with pytest.raises( + CompilationError, match="depends on a node named 'doesnt_exist' which was not found" + ): + run_dbt(["parse"]) class TestModelLevelForeignKeyConstraintRefSyntaxError: @@ -112,10 +125,10 @@ def models(self): def test_model_level_fk_to(self, project): with pytest.raises( - DbtRuntimeError, - match="'model.test.my_model' defines a foreign key constraint 'to' expression which is not valid 'ref' or 'source' syntax", + ParsingError, + match="Invalid 'ref' or 'source' syntax on foreign key constraint 'to' on model my_model: invalid", ): - run_dbt(["compile"]) + run_dbt(["parse"]) class TestColumnLevelForeignKeyConstraintToRef: @@ -129,12 +142,18 @@ def models(self): def test_column_level_fk_to(self, project, unique_schema): manifest = run_dbt(["parse"]) - assert len(manifest.nodes["model.test.my_model"].columns["id"].constraints) == 1 + node_with_fk_constraint = manifest.nodes["model.test.my_model"] + assert len(node_with_fk_constraint.columns["id"].constraints) == 1 - parsed_constraint = manifest.nodes["model.test.my_model"].columns["id"].constraints[0] + parsed_constraint = node_with_fk_constraint.columns["id"].constraints[0] + # Assert column-level constraint parsed assert parsed_constraint == ColumnLevelConstraint( type=ConstraintType.foreign_key, to="ref('my_model_to')", to_columns=["id"] ) + # Assert column-level constraint ref included in node.depends_on + assert node_with_fk_constraint.refs == [RefArgs(name="my_model_to")] + assert node_with_fk_constraint.sources == [] + assert node_with_fk_constraint.depends_on.nodes == ["model.test.my_model_to"] # Assert compilation renders to from 'ref' to relation identifer run_dbt(["compile"]) @@ -161,14 +180,19 @@ def models(self): def test_model_level_fk_to(self, project, unique_schema): manifest = run_dbt(["parse"]) - assert len(manifest.nodes["model.test.my_model"].columns["id"].constraints) == 1 + node_with_fk_constraint = manifest.nodes["model.test.my_model"] + assert len(node_with_fk_constraint.columns["id"].constraints) == 1 - parsed_constraint = manifest.nodes["model.test.my_model"].columns["id"].constraints[0] + parsed_constraint = node_with_fk_constraint.columns["id"].constraints[0] assert parsed_constraint == ColumnLevelConstraint( type=ConstraintType.foreign_key, to="source('test_source', 'test_table')", to_columns=["id"], ) + # Assert column-level constraint source included in node.depends_on + assert node_with_fk_constraint.refs == [] + assert node_with_fk_constraint.depends_on.nodes == ["source.test.test_source.test_table"] + assert node_with_fk_constraint.sources == [["test_source", "test_table"]] # Assert compilation renders to from 'ref' to relation identifer run_dbt(["compile"]) @@ -194,8 +218,10 @@ def models(self): } def test_model_level_fk_to_doesnt_exist(self, project): - with pytest.raises(DbtRuntimeError, match="not in the graph"): - run_dbt(["compile"]) + with pytest.raises( + CompilationError, match="depends on a node named 'doesnt_exist' which was not found" + ): + run_dbt(["parse"]) class TestColumnLevelForeignKeyConstraintRefSyntaxError: @@ -209,7 +235,7 @@ def models(self): def test_model_level_fk_to(self, project): with pytest.raises( - DbtRuntimeError, - match="'model.test.my_model' defines a foreign key constraint 'to' expression which is not valid 'ref' or 'source' syntax", + ParsingError, + match="Invalid 'ref' or 'source' syntax on foreign key constraint 'to' on model my_model: invalid.", ): - run_dbt(["compile"]) + run_dbt(["parse"]) diff --git a/tests/unit/clients/test_jinja_static.py b/tests/unit/clients/test_jinja_static.py index 52dc1f671f0..c714624300c 100644 --- a/tests/unit/clients/test_jinja_static.py +++ b/tests/unit/clients/test_jinja_static.py @@ -4,6 +4,7 @@ from dbt.clients.jinja_static import ( statically_extract_macro_calls, statically_parse_ref, + statically_parse_ref_or_source, statically_parse_source, ) from dbt.context.base import generate_base_context @@ -88,4 +89,24 @@ def test_invalid_expression(self, invalid_expression): def test_valid_ref_expression(self): parsed_source = statically_parse_source("source('schema', 'table')") - assert parsed_source == ("schema", "table") + assert parsed_source == ["schema", "table"] + + +class TestStaticallyParseRefOrSource: + def test_invalid_expression(self): + with pytest.raises(ParsingError): + statically_parse_ref_or_source("invalid") + + @pytest.mark.parametrize( + "expression,expected_ref_or_source", + [ + ("ref('model')", RefArgs(name="model")), + ("ref('package','model')", RefArgs(name="model", package="package")), + ("ref('model',v=3)", RefArgs(name="model", version=3)), + ("ref('package','model',v=3)", RefArgs(name="model", package="package", version=3)), + ("source('schema', 'table')", ["schema", "table"]), + ], + ) + def test_valid_ref_expression(self, expression, expected_ref_or_source): + ref_or_source = statically_parse_ref_or_source(expression) + assert ref_or_source == expected_ref_or_source diff --git a/tests/unit/graph/test_nodes.py b/tests/unit/graph/test_nodes.py index ff14874eb65..79522d06427 100644 --- a/tests/unit/graph/test_nodes.py +++ b/tests/unit/graph/test_nodes.py @@ -68,6 +68,48 @@ def test_is_past_deprecation_date( assert default_model_node.is_past_deprecation_date is expected_is_past_deprecation_date + @pytest.mark.parametrize( + "model_constraints,columns,expected_all_constraints", + [ + ([], {}, []), + ( + [ModelLevelConstraint(type=ConstraintType.foreign_key)], + {}, + [ModelLevelConstraint(type=ConstraintType.foreign_key)], + ), + ( + [], + { + "id": ColumnInfo( + name="id", + constraints=[ColumnLevelConstraint(type=ConstraintType.foreign_key)], + ) + }, + [ColumnLevelConstraint(type=ConstraintType.foreign_key)], + ), + ( + [ModelLevelConstraint(type=ConstraintType.foreign_key)], + { + "id": ColumnInfo( + name="id", + constraints=[ColumnLevelConstraint(type=ConstraintType.foreign_key)], + ) + }, + [ + ModelLevelConstraint(type=ConstraintType.foreign_key), + ColumnLevelConstraint(type=ConstraintType.foreign_key), + ], + ), + ], + ) + def test_all_constraints( + self, default_model_node, model_constraints, columns, expected_all_constraints + ): + default_model_node.constraints = model_constraints + default_model_node.columns = columns + + assert default_model_node.all_constraints == expected_all_constraints + class TestSemanticModel: @pytest.fixture(scope="function") From eca8c3e718b3dfb0abfe54b84214530d26efe336 Mon Sep 17 00:00:00 2001 From: Michelle Ark Date: Mon, 22 Jul 2024 18:13:03 -0400 Subject: [PATCH 11/12] restore dev-requirements.txt dbt-adapters --- dev-requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dev-requirements.txt b/dev-requirements.txt index c5ddd0217e6..20605e632b8 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -1,4 +1,4 @@ -git+https://github.com/dbt-labs/dbt-adapters.git@render-foreign-constraint-ref +git+https://github.com/dbt-labs/dbt-adapters.git@main git+https://github.com/dbt-labs/dbt-adapters.git@main#subdirectory=dbt-tests-adapter git+https://github.com/dbt-labs/dbt-common.git@main git+https://github.com/dbt-labs/dbt-postgres.git@main From dc10493a30b44c60572a06160e95e202622ed02e Mon Sep 17 00:00:00 2001 From: Michelle Ark Date: Mon, 22 Jul 2024 18:21:57 -0400 Subject: [PATCH 12/12] simplify jinja_static --- core/dbt/clients/jinja_static.py | 69 +++++++------------------ tests/unit/clients/test_jinja_static.py | 33 ------------ 2 files changed, 19 insertions(+), 83 deletions(-) diff --git a/core/dbt/clients/jinja_static.py b/core/dbt/clients/jinja_static.py index 6082f03f80c..d8746a7607d 100644 --- a/core/dbt/clients/jinja_static.py +++ b/core/dbt/clients/jinja_static.py @@ -157,68 +157,37 @@ def statically_parse_adapter_dispatch(func_call, ctx, db_wrapper): return possible_macro_calls -def statically_parse_ref(input: str) -> RefArgs: +def statically_parse_ref_or_source(expression: str) -> Union[RefArgs, List[str]]: """ - Returns a RefArgs object corresponding to an input jinja expression. + Returns a RefArgs or List[str] object, corresponding to ref or source respectively, given an input jinja expression. input: str representing how input node is referenced in tested model sql * examples: - "ref('my_model_a')" - "ref('my_model_a', version=3)" - "ref('package', 'my_model_a', version=3)" - - If input is not a well-formed jinja ref expression, a ParsingError is raised. - """ - try: - statically_parsed = py_extract_from_source(f"{{{{ {input} }}}}") - except ExtractionError: - raise ParsingError(f"Invalid jinja expression: {input}") - - if not statically_parsed.get("refs"): - raise ParsingError(f"Invalid ref expression: {input}") - - ref = list(statically_parsed["refs"])[0] - return RefArgs(package=ref.get("package"), name=ref.get("name"), version=ref.get("version")) - - -def statically_parse_source(input: str) -> List[str]: - """ - Returns a RefArgs object corresponding to an input jinja expression. - - input: str representing how input node is referenced in tested model sql - * examples: - "source('my_source_schema', 'my_source_name')" - If input is not a well-formed jinja source expression, ParsingError is raised. + If input is not a well-formed jinja ref or source expression, a ParsingError is raised. """ - try: - statically_parsed = py_extract_from_source(f"{{{{ {input} }}}}") - except ExtractionError: - raise ParsingError(f"Invalid jinja expression: {input}") - - if not statically_parsed.get("sources"): - raise ParsingError(f"Invalid source expression: {input}") - - source = list(statically_parsed["sources"])[0] - source_name, source_table_name = source - return [source_name, source_table_name] - - -def statically_parse_ref_or_source(expression: str) -> Union[RefArgs, List[str]]: ref_or_source: Union[RefArgs, List[str]] - valid_ref = True - valid_source = True try: - ref_or_source = statically_parse_ref(expression) - except ParsingError: - valid_ref = False - try: - ref_or_source = statically_parse_source(expression) - except ParsingError: - valid_source = False - - if not valid_ref and not valid_source: - raise ParsingError(f"Invalid ref or source syntax: {expression}.") + statically_parsed = py_extract_from_source(f"{{{{ {expression} }}}}") + except ExtractionError: + raise ParsingError(f"Invalid jinja expression: {expression}") + + if statically_parsed.get("refs"): + raw_ref = list(statically_parsed["refs"])[0] + ref_or_source = RefArgs( + package=raw_ref.get("package"), + name=raw_ref.get("name"), + version=raw_ref.get("version"), + ) + elif statically_parsed.get("sources"): + source_name, source_table_name = list(statically_parsed["sources"])[0] + ref_or_source = [source_name, source_table_name] + else: + raise ParsingError(f"Invalid ref or source expression: {expression}") return ref_or_source diff --git a/tests/unit/clients/test_jinja_static.py b/tests/unit/clients/test_jinja_static.py index c714624300c..171976a6b50 100644 --- a/tests/unit/clients/test_jinja_static.py +++ b/tests/unit/clients/test_jinja_static.py @@ -3,9 +3,7 @@ from dbt.artifacts.resources import RefArgs from dbt.clients.jinja_static import ( statically_extract_macro_calls, - statically_parse_ref, statically_parse_ref_or_source, - statically_parse_source, ) from dbt.context.base import generate_base_context from dbt.exceptions import ParsingError @@ -61,37 +59,6 @@ def test_extract_macro_calls(macro_string, expected_possible_macro_calls): assert possible_macro_calls == expected_possible_macro_calls -class TestStaticallyParseRef: - @pytest.mark.parametrize("invalid_expression", ["invalid", "source('schema', 'table')"]) - def test_invalid_expression(self, invalid_expression): - with pytest.raises(ParsingError): - statically_parse_ref(invalid_expression) - - @pytest.mark.parametrize( - "ref_expression,expected_ref_args", - [ - ("ref('model')", RefArgs(name="model")), - ("ref('package','model')", RefArgs(name="model", package="package")), - ("ref('model',v=3)", RefArgs(name="model", version=3)), - ("ref('package','model',v=3)", RefArgs(name="model", package="package", version=3)), - ], - ) - def test_valid_ref_expression(self, ref_expression, expected_ref_args): - ref_args = statically_parse_ref(ref_expression) - assert ref_args == expected_ref_args - - -class TestStaticallyParseSource: - @pytest.mark.parametrize("invalid_expression", ["invalid", "ref('package', 'model')"]) - def test_invalid_expression(self, invalid_expression): - with pytest.raises(ParsingError): - statically_parse_source(invalid_expression) - - def test_valid_ref_expression(self): - parsed_source = statically_parse_source("source('schema', 'table')") - assert parsed_source == ["schema", "table"] - - class TestStaticallyParseRefOrSource: def test_invalid_expression(self): with pytest.raises(ParsingError):