From 7f777f8a4290ad1cd0552d5207ad09e2da1a84bb Mon Sep 17 00:00:00 2001 From: Michelle Ark Date: Tue, 5 Dec 2023 12:03:16 +0900 Subject: [PATCH 1/8] move relation contract to dbt.adapters --- core/dbt/adapters/base/relation.py | 2 +- core/dbt/{ => adapters}/contracts/relation.py | 5 ++--- core/dbt/adapters/protocol.py | 2 +- core/dbt/common/exceptions/__init__.py | 1 + core/dbt/common/exceptions/contracts.py | 17 +++++++++++++++++ core/dbt/config/runtime.py | 6 +++--- core/dbt/context/exceptions_jinja.py | 8 ++++++-- core/dbt/contracts/util.py | 6 +----- core/dbt/exceptions.py | 15 --------------- .../postgres/dbt/adapters/postgres/relation.py | 2 +- .../tests/adapter/materialized_view/basic.py | 2 +- .../tests/adapter/materialized_view/changes.py | 2 +- tests/unit/test_postgres_adapter.py | 2 +- tests/unit/test_relation.py | 2 +- 14 files changed, 37 insertions(+), 35 deletions(-) rename core/dbt/{ => adapters}/contracts/relation.py (96%) create mode 100644 core/dbt/common/exceptions/contracts.py diff --git a/core/dbt/adapters/base/relation.py b/core/dbt/adapters/base/relation.py index 040d5e94442..70a01398f0d 100644 --- a/core/dbt/adapters/base/relation.py +++ b/core/dbt/adapters/base/relation.py @@ -3,7 +3,7 @@ from typing import Optional, TypeVar, Any, Type, Dict, Iterator, Tuple, Set, Union, FrozenSet from dbt.contracts.graph.nodes import SourceDefinition, ManifestNode, ResultNode, ParsedNode -from dbt.contracts.relation import ( +from dbt.adapters.contracts.relation import ( RelationType, ComponentName, HasQuoting, diff --git a/core/dbt/contracts/relation.py b/core/dbt/adapters/contracts/relation.py similarity index 96% rename from core/dbt/contracts/relation.py rename to core/dbt/adapters/contracts/relation.py index 62a62d814ee..c4cead46e45 100644 --- a/core/dbt/contracts/relation.py +++ b/core/dbt/adapters/contracts/relation.py @@ -8,9 +8,8 @@ from dbt.common.dataclass_schema import dbtClassMixin, StrEnum -from dbt.contracts.util import Replaceable -from dbt.common.exceptions import CompilationError -from dbt.exceptions import DataclassNotDictError +from dbt.common.contracts.util import Replaceable +from dbt.common.exceptions import CompilationError, DataclassNotDictError from dbt.common.utils import deep_merge diff --git a/core/dbt/adapters/protocol.py b/core/dbt/adapters/protocol.py index 04c1b2fd0c5..45d86bcc307 100644 --- a/core/dbt/adapters/protocol.py +++ b/core/dbt/adapters/protocol.py @@ -14,10 +14,10 @@ import agate from dbt.adapters.contracts.connection import Connection, AdapterRequiredConfig, AdapterResponse +from dbt.adapters.contracts.relation import Policy, HasQuoting from dbt.contracts.graph.nodes import ResultNode from dbt.contracts.graph.model_config import BaseConfig from dbt.contracts.graph.manifest import Manifest -from dbt.contracts.relation import Policy, HasQuoting @dataclass diff --git a/core/dbt/common/exceptions/__init__.py b/core/dbt/common/exceptions/__init__.py index 208ba24dbf7..da49548aa9d 100644 --- a/core/dbt/common/exceptions/__init__.py +++ b/core/dbt/common/exceptions/__init__.py @@ -1,3 +1,4 @@ from dbt.common.exceptions.base import * # noqa from dbt.common.exceptions.events import * # noqa from dbt.common.exceptions.macros import * # noqa +from dbt.common.exceptions.contracts import * # noqa diff --git a/core/dbt/common/exceptions/contracts.py b/core/dbt/common/exceptions/contracts.py new file mode 100644 index 00000000000..6b32793dea5 --- /dev/null +++ b/core/dbt/common/exceptions/contracts.py @@ -0,0 +1,17 @@ +from typing import Any +from dbt.common.exceptions import CompilationError + + +# this is part of the context and also raised in dbt.contracts.relation.py +class DataclassNotDictError(CompilationError): + def __init__(self, obj: Any): + self.obj = obj + super().__init__(msg=self.get_message()) + + def get_message(self) -> str: + msg = ( + f'The object ("{self.obj}") was used as a dictionary. This ' + "capability has been removed from objects of this type." + ) + + return msg diff --git a/core/dbt/config/runtime.py b/core/dbt/config/runtime.py index 2a66f2f31b5..76b10e009da 100644 --- a/core/dbt/config/runtime.py +++ b/core/dbt/config/runtime.py @@ -15,13 +15,13 @@ Type, ) -from dbt.flags import get_flags from dbt.adapters.factory import get_include_paths, get_relation_class_by_name -from dbt.config.project import load_raw_project from dbt.adapters.contracts.connection import AdapterRequiredConfig, Credentials, HasCredentials +from dbt.adapters.contracts.relation import ComponentName +from dbt.flags import get_flags +from dbt.config.project import load_raw_project from dbt.contracts.graph.manifest import ManifestMetadata from dbt.contracts.project import Configuration, UserConfig -from dbt.contracts.relation import ComponentName from dbt.common.dataclass_schema import ValidationError from dbt.common.events.functions import warn_or_error from dbt.common.events.types import UnusedResourceConfigPath diff --git a/core/dbt/context/exceptions_jinja.py b/core/dbt/context/exceptions_jinja.py index 87d7977982f..87e79fea3ad 100644 --- a/core/dbt/context/exceptions_jinja.py +++ b/core/dbt/context/exceptions_jinja.py @@ -4,7 +4,12 @@ from dbt.common.events.functions import warn_or_error from dbt.common.events.types import JinjaLogWarning -from dbt.common.exceptions import DbtRuntimeError, NotImplementedError, DbtDatabaseError +from dbt.common.exceptions import ( + DbtRuntimeError, + NotImplementedError, + DbtDatabaseError, + DataclassNotDictError, +) from dbt.adapters.exceptions import ( MissingConfigError, ColumnTypeMissingError, @@ -15,7 +20,6 @@ MissingRelationError, AmbiguousAliasError, AmbiguousCatalogMatchError, - DataclassNotDictError, CompilationError, DependencyNotFoundError, DependencyError, diff --git a/core/dbt/contracts/util.py b/core/dbt/contracts/util.py index c22efbd1220..7c33c18dca1 100644 --- a/core/dbt/contracts/util.py +++ b/core/dbt/contracts/util.py @@ -10,6 +10,7 @@ ) from dbt.version import __version__ +from dbt.common.contracts.util import Replaceable from dbt.common.events.functions import get_metadata_vars from dbt.common.invocation import get_invocation_id from dbt.common.dataclass_schema import dbtClassMixin @@ -41,11 +42,6 @@ class Foo: return [] -class Replaceable: - def replace(self, **kwargs): - return dataclasses.replace(self, **kwargs) - - class Mergeable(Replaceable): def merged(self, *args): """Perform a shallow merge, where the last non-None write wins. This is diff --git a/core/dbt/exceptions.py b/core/dbt/exceptions.py index e3ad3a48163..d08b03a257b 100644 --- a/core/dbt/exceptions.py +++ b/core/dbt/exceptions.py @@ -1363,21 +1363,6 @@ def get_message(self) -> str: return msg -# this is part of the context and also raised in dbt.contracts.relation.py -class DataclassNotDictError(CompilationError): - def __init__(self, obj: Any): - self.obj = obj - super().__init__(msg=self.get_message()) - - def get_message(self) -> str: - msg = ( - f'The object ("{self.obj}") was used as a dictionary. This ' - "capability has been removed from objects of this type." - ) - - return msg - - class DependencyNotFoundError(CompilationError): def __init__(self, node, node_description, required_pkg): self.node = node diff --git a/plugins/postgres/dbt/adapters/postgres/relation.py b/plugins/postgres/dbt/adapters/postgres/relation.py index fbb358cde43..5c76052ebf4 100644 --- a/plugins/postgres/dbt/adapters/postgres/relation.py +++ b/plugins/postgres/dbt/adapters/postgres/relation.py @@ -7,7 +7,7 @@ RelationResults, ) from dbt.context.providers import RuntimeConfigObject -from dbt.contracts.relation import RelationType +from dbt.adapters.contracts.relation import RelationType from dbt.common.exceptions import DbtRuntimeError from dbt.adapters.postgres.relation_configs import ( diff --git a/tests/adapter/dbt/tests/adapter/materialized_view/basic.py b/tests/adapter/dbt/tests/adapter/materialized_view/basic.py index ec90d503650..9720945ba50 100644 --- a/tests/adapter/dbt/tests/adapter/materialized_view/basic.py +++ b/tests/adapter/dbt/tests/adapter/materialized_view/basic.py @@ -3,7 +3,7 @@ import pytest from dbt.adapters.base.relation import BaseRelation -from dbt.contracts.relation import RelationType +from dbt.adapters.contracts.relation import RelationType from dbt.tests.util import ( assert_message_in_logs, get_model_file, diff --git a/tests/adapter/dbt/tests/adapter/materialized_view/changes.py b/tests/adapter/dbt/tests/adapter/materialized_view/changes.py index 5fc933fbe0d..b31149a5ac2 100644 --- a/tests/adapter/dbt/tests/adapter/materialized_view/changes.py +++ b/tests/adapter/dbt/tests/adapter/materialized_view/changes.py @@ -4,7 +4,7 @@ from dbt.adapters.base.relation import BaseRelation from dbt.contracts.graph.model_config import OnConfigurationChangeOption -from dbt.contracts.relation import RelationType +from dbt.adapters.contracts.relation import RelationType from dbt.tests.util import ( assert_message_in_logs, get_model_file, diff --git a/tests/unit/test_postgres_adapter.py b/tests/unit/test_postgres_adapter.py index 8739e3e2784..1dfb7a9146a 100644 --- a/tests/unit/test_postgres_adapter.py +++ b/tests/unit/test_postgres_adapter.py @@ -7,7 +7,7 @@ from unittest import mock from dbt.adapters.base import BaseRelation -from dbt.contracts.relation import Path +from dbt.adapters.contracts.relation import Path from dbt.task.debug import DebugTask from dbt.adapters.base.query_headers import MacroQueryStringSetter diff --git a/tests/unit/test_relation.py b/tests/unit/test_relation.py index 94995958ba6..5a5c58de7a5 100644 --- a/tests/unit/test_relation.py +++ b/tests/unit/test_relation.py @@ -3,7 +3,7 @@ import pytest from dbt.adapters.base import BaseRelation -from dbt.contracts.relation import RelationType +from dbt.adapters.contracts.relation import RelationType @pytest.mark.parametrize( From 2cee8652a661a9af41f1b7d4cfd27f712ec21226 Mon Sep 17 00:00:00 2001 From: Michelle Ark Date: Tue, 5 Dec 2023 12:06:06 +0900 Subject: [PATCH 2/8] changelog entry --- .changes/unreleased/Under the Hood-20231205-120559.yaml | 6 ++++++ 1 file changed, 6 insertions(+) create mode 100644 .changes/unreleased/Under the Hood-20231205-120559.yaml diff --git a/.changes/unreleased/Under the Hood-20231205-120559.yaml b/.changes/unreleased/Under the Hood-20231205-120559.yaml new file mode 100644 index 00000000000..a209bda9f6c --- /dev/null +++ b/.changes/unreleased/Under the Hood-20231205-120559.yaml @@ -0,0 +1,6 @@ +kind: Under the Hood +body: Remove usage of dbt.contracts in dbt/adapters +time: 2023-12-05T12:05:59.936775+09:00 +custom: + Author: michelleark + Issue: "9208" From 160d0db2383e5aea7d46681487c591033b6e5202 Mon Sep 17 00:00:00 2001 From: Michelle Ark Date: Tue, 5 Dec 2023 15:03:16 +0900 Subject: [PATCH 3/8] first pass: clean up relation.create_from --- core/dbt/adapters/base/relation.py | 75 +++++++------------------ core/dbt/adapters/contracts/relation.py | 8 +++ core/dbt/adapters/protocol.py | 18 ++---- core/dbt/context/providers.py | 11 +--- core/dbt/contracts/graph/nodes.py | 7 +++ core/dbt/parser/manifest.py | 2 +- core/dbt/task/clone.py | 2 +- core/dbt/task/freshness.py | 2 +- 8 files changed, 46 insertions(+), 79 deletions(-) diff --git a/core/dbt/adapters/base/relation.py b/core/dbt/adapters/base/relation.py index 70a01398f0d..33d5011e1cb 100644 --- a/core/dbt/adapters/base/relation.py +++ b/core/dbt/adapters/base/relation.py @@ -2,8 +2,8 @@ from dataclasses import dataclass, field from typing import Optional, TypeVar, Any, Type, Dict, Iterator, Tuple, Set, Union, FrozenSet -from dbt.contracts.graph.nodes import SourceDefinition, ManifestNode, ResultNode, ParsedNode from dbt.adapters.contracts.relation import ( + RelationConfig, RelationType, ComponentName, HasQuoting, @@ -11,9 +11,7 @@ Policy, Path, ) -from dbt.common.exceptions import DbtInternalError from dbt.adapters.exceptions import MultipleDatabasesNotAllowedError, ApproximateMatchError -from dbt.node_types import NodeType from dbt.common.utils import filter_null_values, deep_merge from dbt.adapters.utils import classproperty @@ -198,33 +196,14 @@ def quoted(self, identifier): identifier=identifier, ) - @classmethod - def create_from_source(cls: Type[Self], source: SourceDefinition, **kwargs: Any) -> Self: - source_quoting = source.quoting.to_dict(omit_none=True) - source_quoting.pop("column", None) - quote_policy = deep_merge( - cls.get_default_quote_policy().to_dict(omit_none=True), - source_quoting, - kwargs.get("quote_policy", {}), - ) - - return cls.create( - database=source.database, - schema=source.schema, - identifier=source.identifier, - quote_policy=quote_policy, - **kwargs, - ) - @staticmethod def add_ephemeral_prefix(name: str): return f"__dbt__cte__{name}" @classmethod - def create_ephemeral_from_node( + def create_ephemeral_from( cls: Type[Self], - config: HasQuoting, - node: ManifestNode, + node: RelationConfig, ) -> Self: # Note that ephemeral models are based on the name. identifier = cls.add_ephemeral_prefix(node.name) @@ -234,47 +213,33 @@ def create_ephemeral_from_node( ).quote(identifier=False) @classmethod - def create_from_node( + def create_from( cls: Type[Self], - config: HasQuoting, - node, - quote_policy: Optional[Dict[str, bool]] = None, + quoting: HasQuoting, + config: RelationConfig, **kwargs: Any, ) -> Self: - if quote_policy is None: - quote_policy = {} + quote_policy = kwargs.pop("quote_policy", {}) + + config_quoting = config.quoting_dict + config_quoting.pop("column", None) - quote_policy = dbt.common.utils.merge(config.quoting, quote_policy) + # precedence: kwargs quoting > config quoting > base quoting > default quoting + quote_policy = deep_merge( + cls.get_default_quote_policy().to_dict(omit_none=True), + quoting.quoting, + config_quoting, + quote_policy, + ) return cls.create( - database=node.database, - schema=node.schema, - identifier=node.alias, + database=config.database, + schema=config.schema, + identifier=config.identifier, quote_policy=quote_policy, **kwargs, ) - @classmethod - def create_from( - cls: Type[Self], - config: HasQuoting, - node: ResultNode, - **kwargs: Any, - ) -> Self: - if node.resource_type == NodeType.Source: - if not isinstance(node, SourceDefinition): - raise DbtInternalError( - "type mismatch, expected SourceDefinition but got {}".format(type(node)) - ) - return cls.create_from_source(node, **kwargs) - else: - # Can't use ManifestNode here because of parameterized generics - if not isinstance(node, (ParsedNode)): - raise DbtInternalError( - f"type mismatch, expected ManifestNode but got {type(node)}" - ) - return cls.create_from_node(config, node, **kwargs) - @classmethod def create( cls: Type[Self], diff --git a/core/dbt/adapters/contracts/relation.py b/core/dbt/adapters/contracts/relation.py index c4cead46e45..98e9d8ef878 100644 --- a/core/dbt/adapters/contracts/relation.py +++ b/core/dbt/adapters/contracts/relation.py @@ -22,6 +22,14 @@ class RelationType(StrEnum): Ephemeral = "ephemeral" +class RelationConfig(Protocol): + name: str + database: str + schema: str + identifier: str + quoting_dict: Dict[str, bool] + + class ComponentName(StrEnum): Database = "database" Schema = "schema" diff --git a/core/dbt/adapters/protocol.py b/core/dbt/adapters/protocol.py index 45d86bcc307..270a0c4c969 100644 --- a/core/dbt/adapters/protocol.py +++ b/core/dbt/adapters/protocol.py @@ -1,21 +1,11 @@ from dataclasses import dataclass -from typing import ( - Type, - Hashable, - Optional, - ContextManager, - List, - Generic, - TypeVar, - Tuple, -) +from typing import Type, Hashable, Optional, ContextManager, List, Generic, TypeVar, Tuple, Any from typing_extensions import Protocol import agate from dbt.adapters.contracts.connection import Connection, AdapterRequiredConfig, AdapterResponse -from dbt.adapters.contracts.relation import Policy, HasQuoting -from dbt.contracts.graph.nodes import ResultNode +from dbt.adapters.contracts.relation import Policy, HasQuoting, RelationConfig from dbt.contracts.graph.model_config import BaseConfig from dbt.contracts.graph.manifest import Manifest @@ -42,7 +32,9 @@ def get_default_quote_policy(cls) -> Policy: ... @classmethod - def create_from(cls: Type[Self], config: HasQuoting, node: ResultNode) -> Self: + def create_from( + cls: Type[Self], quoting: HasQuoting, config: RelationConfig, **kwargs: Any + ) -> Self: ... diff --git a/core/dbt/context/providers.py b/core/dbt/context/providers.py index e9094d4518c..2ecb5a0b84a 100644 --- a/core/dbt/context/providers.py +++ b/core/dbt/context/providers.py @@ -89,11 +89,6 @@ def __init__(self, adapter): def __getattr__(self, key): return getattr(self._relation_type, key) - def create_from_source(self, *args, **kwargs): - # bypass our create when creating from source so as not to mess up - # the source quoting - return self._relation_type.create_from_source(*args, **kwargs) - def create(self, *args, **kwargs): kwargs["quote_policy"] = merge(self._quoting_config, kwargs.pop("quote_policy", {})) return self._relation_type.create(*args, **kwargs) @@ -529,7 +524,7 @@ def resolve( def create_relation(self, target_model: ManifestNode) -> RelationProxy: if target_model.is_ephemeral_model: self.model.set_cte(target_model.unique_id, None) - return self.Relation.create_ephemeral_from_node(self.config, target_model) + return self.Relation.create_ephemeral_from(target_model) else: return self.Relation.create_from(self.config, target_model) @@ -588,7 +583,7 @@ def resolve(self, source_name: str, table_name: str): target_kind="source", disabled=(isinstance(target_source, Disabled)), ) - return self.Relation.create_from_source(target_source) + return self.Relation.create_from(self.config, target_source) # metric` implementations @@ -1475,7 +1470,7 @@ def defer_relation(self) -> Optional[RelationProxy]: object for that stateful other """ if getattr(self.model, "defer_relation", None): - return self.db_wrapper.Relation.create_from_node( + return self.db_wrapper.Relation.create_from( self.config, self.model.defer_relation # type: ignore ) else: diff --git a/core/dbt/contracts/graph/nodes.py b/core/dbt/contracts/graph/nodes.py index fad8a427016..1ab8b9b3e84 100644 --- a/core/dbt/contracts/graph/nodes.py +++ b/core/dbt/contracts/graph/nodes.py @@ -219,6 +219,13 @@ def __pre_deserialize__(cls, data): data["database"] = None return data + @property + def quoting_dict(self) -> Dict[str, bool]: + if hasattr(self, "quoting"): + return self.quoting.to_dict(omit_none=True) + else: + return {} + @dataclass class MacroDependsOn(dbtClassMixin, Replaceable): diff --git a/core/dbt/parser/manifest.py b/core/dbt/parser/manifest.py index d1dabeaf213..8bf76f0c45a 100644 --- a/core/dbt/parser/manifest.py +++ b/core/dbt/parser/manifest.py @@ -1357,7 +1357,7 @@ def _check_resource_uniqueness( # the full node name is really defined by the adapter's relation relation_cls = get_relation_class_by_name(config.credentials.type) - relation = relation_cls.create_from(config=config, node=node) + relation = relation_cls.create_from(quoting=config, config=node) full_node_name = str(relation) existing_alias = alias_resources.get(full_node_name) diff --git a/core/dbt/task/clone.py b/core/dbt/task/clone.py index 089bc7be265..7a782682f65 100644 --- a/core/dbt/task/clone.py +++ b/core/dbt/task/clone.py @@ -108,7 +108,7 @@ def get_model_schemas(self, adapter, selected_uids: Iterable[str]) -> Set[BaseRe # cache the 'other' schemas too! if node.defer_relation: # type: ignore - other_relation = adapter.Relation.create_from_node( + other_relation = adapter.Relation.create_from( self.config, node.defer_relation # type: ignore ) result.add(other_relation.without_identifier()) diff --git a/core/dbt/task/freshness.py b/core/dbt/task/freshness.py index 3e71345db6c..73a1c70739a 100644 --- a/core/dbt/task/freshness.py +++ b/core/dbt/task/freshness.py @@ -99,7 +99,7 @@ def from_run_result(self, result, start_time, timing_info): return result def execute(self, compiled_node, manifest): - relation = self.adapter.Relation.create_from_source(compiled_node) + relation = self.adapter.Relation.create_from(self.config, compiled_node) # given a Source, calculate its freshness. with self.adapter.connection_for(compiled_node): self.adapter.clear_transaction() From d7d5e2335c2eb6e86d7468694242b1b8e80dece9 Mon Sep 17 00:00:00 2001 From: Michelle Ark Date: Tue, 5 Dec 2023 15:41:13 +0900 Subject: [PATCH 4/8] type ignores --- core/dbt/adapters/base/impl.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/dbt/adapters/base/impl.py b/core/dbt/adapters/base/impl.py index e4806fdc8b7..80b3bea6bc2 100644 --- a/core/dbt/adapters/base/impl.py +++ b/core/dbt/adapters/base/impl.py @@ -429,7 +429,7 @@ def _get_cache_schemas(self, manifest: Manifest) -> Set[BaseRelation]: """ # the cache only cares about executable nodes return { - self.Relation.create_from(self.config, node).without_identifier() + self.Relation.create_from(self.config, node).without_identifier() # type: ignore[arg-type] for node in manifest.nodes.values() if (node.is_relational and not node.is_ephemeral_model and not node.is_external_node) } @@ -476,7 +476,7 @@ def _get_catalog_relations(self, manifest: Manifest) -> List[BaseRelation]: manifest.sources.values(), ) - relations = [self.Relation.create_from(self.config, n) for n in nodes] + relations = [self.Relation.create_from(self.config, n) for n in nodes] # type: ignore[arg-type] return relations def _relations_cache_for_schemas( From b8de881ed3ae86b5c18a7a6908b2b2c8c1195b90 Mon Sep 17 00:00:00 2001 From: Michelle Ark Date: Tue, 5 Dec 2023 17:06:49 +0900 Subject: [PATCH 5/8] type ignore --- core/dbt/parser/manifest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/dbt/parser/manifest.py b/core/dbt/parser/manifest.py index 8bf76f0c45a..6aaf4482e6e 100644 --- a/core/dbt/parser/manifest.py +++ b/core/dbt/parser/manifest.py @@ -1357,7 +1357,7 @@ def _check_resource_uniqueness( # the full node name is really defined by the adapter's relation relation_cls = get_relation_class_by_name(config.credentials.type) - relation = relation_cls.create_from(quoting=config, config=node) + relation = relation_cls.create_from(quoting=config, config=node) # type: ignore[arg-type] full_node_name = str(relation) existing_alias = alias_resources.get(full_node_name) From ba53f053fdc1bfe87cd6cc1e331063d4410417b0 Mon Sep 17 00:00:00 2001 From: Michelle Ark Date: Tue, 5 Dec 2023 17:07:32 +0900 Subject: [PATCH 6/8] changelog entry --- .changes/unreleased/Under the Hood-20231205-170725.yaml | 6 ++++++ 1 file changed, 6 insertions(+) create mode 100644 .changes/unreleased/Under the Hood-20231205-170725.yaml diff --git a/.changes/unreleased/Under the Hood-20231205-170725.yaml b/.changes/unreleased/Under the Hood-20231205-170725.yaml new file mode 100644 index 00000000000..2018825bcff --- /dev/null +++ b/.changes/unreleased/Under the Hood-20231205-170725.yaml @@ -0,0 +1,6 @@ +kind: Under the Hood +body: Introduce RelationConfig Protocol, consolidate Relation.create_from +time: 2023-12-05T17:07:25.33861+09:00 +custom: + Author: michelleark + Issue: "9215" From 7ad6aa18dad1540282c716164be8bb56e3979e40 Mon Sep 17 00:00:00 2001 From: Michelle Ark Date: Wed, 6 Dec 2023 11:02:39 +0900 Subject: [PATCH 7/8] update RelationConfig variable names --- core/dbt/adapters/base/relation.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/core/dbt/adapters/base/relation.py b/core/dbt/adapters/base/relation.py index 33d5011e1cb..af508f438e5 100644 --- a/core/dbt/adapters/base/relation.py +++ b/core/dbt/adapters/base/relation.py @@ -203,10 +203,10 @@ def add_ephemeral_prefix(name: str): @classmethod def create_ephemeral_from( cls: Type[Self], - node: RelationConfig, + relation_config: RelationConfig, ) -> Self: # Note that ephemeral models are based on the name. - identifier = cls.add_ephemeral_prefix(node.name) + identifier = cls.add_ephemeral_prefix(relation_config.name) return cls.create( type=cls.CTE, identifier=identifier, @@ -216,15 +216,15 @@ def create_ephemeral_from( def create_from( cls: Type[Self], quoting: HasQuoting, - config: RelationConfig, + relation_config: RelationConfig, **kwargs: Any, ) -> Self: quote_policy = kwargs.pop("quote_policy", {}) - config_quoting = config.quoting_dict + config_quoting = relation_config.quoting_dict config_quoting.pop("column", None) - # precedence: kwargs quoting > config quoting > base quoting > default quoting + # precedence: kwargs quoting > relation config quoting > base quoting > default quoting quote_policy = deep_merge( cls.get_default_quote_policy().to_dict(omit_none=True), quoting.quoting, @@ -233,9 +233,9 @@ def create_from( ) return cls.create( - database=config.database, - schema=config.schema, - identifier=config.identifier, + database=relation_config.database, + schema=relation_config.schema, + identifier=relation_config.identifier, quote_policy=quote_policy, **kwargs, ) From f68af070f3f4660693268b6a9e7223ff2e988e3f Mon Sep 17 00:00:00 2001 From: Michelle Ark Date: Wed, 6 Dec 2023 23:27:42 +0900 Subject: [PATCH 8/8] remove manifest from adapter.execute_macro, replace with MacroResolver + remove lazy loading --- core/dbt/adapters/base/impl.py | 77 ++++++++++----------------- core/dbt/adapters/contracts/macros.py | 11 ++++ core/dbt/adapters/protocol.py | 10 ++++ core/dbt/context/providers.py | 5 +- core/dbt/parser/manifest.py | 4 +- core/dbt/task/run_operation.py | 2 +- core/dbt/task/show.py | 2 +- core/dbt/tests/fixtures/project.py | 17 +++++- tests/unit/test_postgres_adapter.py | 4 +- 9 files changed, 73 insertions(+), 59 deletions(-) create mode 100644 core/dbt/adapters/contracts/macros.py diff --git a/core/dbt/adapters/base/impl.py b/core/dbt/adapters/base/impl.py index dffd4f65f2f..9d06dcdc674 100644 --- a/core/dbt/adapters/base/impl.py +++ b/core/dbt/adapters/base/impl.py @@ -9,7 +9,6 @@ Any, Callable, Dict, - Iterable, Iterator, List, Mapping, @@ -28,6 +27,7 @@ ConstraintType, ModelLevelConstraint, ) +from dbt.adapters.contracts.macros import MacroResolver import agate import pytz @@ -62,7 +62,7 @@ Integer, ) from dbt.common.clients.jinja import CallableMacroGenerator -from dbt.contracts.graph.manifest import Manifest, MacroManifest +from dbt.contracts.graph.manifest import Manifest from dbt.common.events.functions import fire_event, warn_or_error from dbt.adapters.events.types import ( CacheMiss, @@ -254,7 +254,20 @@ def __init__(self, config, mp_context: SpawnContext) -> None: self.config = config self.cache = RelationsCache(log_cache_events=config.log_cache_events) self.connections = self.ConnectionManager(config, mp_context) - self._macro_manifest_lazy: Optional[MacroManifest] = None + self._macro_resolver: Optional[MacroResolver] = None + + ### + # Methods to set / access a macro resolver + ### + def set_macro_resolver(self, macro_resolver: MacroResolver) -> None: + self._macro_resolver = macro_resolver + + def get_macro_resolver(self) -> Optional[MacroResolver]: + return self._macro_resolver + + def clear_macro_resolver(self) -> None: + if self._macro_resolver is not None: + self._macro_resolver = None ### # Methods that pass through to the connection manager @@ -367,39 +380,6 @@ def type(cls) -> str: """ return cls.ConnectionManager.TYPE - @property - def _macro_manifest(self) -> MacroManifest: - if self._macro_manifest_lazy is None: - return self.load_macro_manifest() - return self._macro_manifest_lazy - - def check_macro_manifest(self) -> Optional[MacroManifest]: - """Return the internal manifest (used for executing macros) if it's - been initialized, otherwise return None. - """ - return self._macro_manifest_lazy - - def load_macro_manifest(self, base_macros_only=False) -> MacroManifest: - # base_macros_only is for the test framework - if self._macro_manifest_lazy is None: - # avoid a circular import - from dbt.parser.manifest import ManifestLoader - - manifest = ManifestLoader.load_macros( - self.config, - self.connections.set_query_header, - base_macros_only=base_macros_only, - ) - # TODO CT-211 - self._macro_manifest_lazy = manifest # type: ignore[assignment] - # TODO CT-211 - return self._macro_manifest_lazy # type: ignore[return-value] - - def clear_macro_manifest(self): - if self._macro_manifest_lazy is not None: - self._macro_manifest_lazy = None - - ### # Caching methods ### def _schema_is_cached(self, database: Optional[str], schema: str) -> bool: @@ -1054,11 +1034,10 @@ def convert_agate_type(cls, agate_table: agate.Table, col_idx: int) -> Optional[ def execute_macro( self, macro_name: str, - manifest: Optional[Manifest] = None, + macro_resolver: Optional[MacroResolver] = None, project: Optional[str] = None, context_override: Optional[Dict[str, Any]] = None, kwargs: Optional[Dict[str, Any]] = None, - text_only_columns: Optional[Iterable[str]] = None, ) -> AttrDict: """Look macro_name up in the manifest and execute its results. @@ -1078,13 +1057,11 @@ def execute_macro( if context_override is None: context_override = {} - if manifest is None: - # TODO CT-211 - manifest = self._macro_manifest # type: ignore[assignment] - # TODO CT-211 - macro = manifest.find_macro_by_name( # type: ignore[union-attr] - macro_name, self.config.project_name, project - ) + resolver = macro_resolver or self._macro_resolver + if resolver is None: + raise DbtInternalError("macro resolver was None when calling execute_macro!") + + macro = resolver.find_macro_by_name(macro_name, self.config.project_name, project) if macro is None: if project is None: package_name = "any package" @@ -1104,7 +1081,7 @@ def execute_macro( # TODO CT-211 macro=macro, config=self.config, - manifest=manifest, # type: ignore[arg-type] + manifest=resolver, # type: ignore[arg-type] package_name=project, ) macro_context.update(context_override) @@ -1140,7 +1117,7 @@ def _get_one_catalog( kwargs=kwargs, # pass in the full manifest so we get any local project # overrides - manifest=manifest, + macro_resolver=manifest, ) results = self._catalog_filter_table(table, manifest) # type: ignore[arg-type] @@ -1162,7 +1139,7 @@ def _get_one_catalog_by_relations( kwargs=kwargs, # pass in the full manifest, so we get any local project # overrides - manifest=manifest, + macro_resolver=manifest, ) results = self._catalog_filter_table(table, manifest) # type: ignore[arg-type] @@ -1273,7 +1250,7 @@ def calculate_freshness( AttrDict, # current: contains AdapterResponse + agate.Table agate.Table, # previous: just table ] - result = self.execute_macro(FRESHNESS_MACRO_NAME, kwargs=kwargs, manifest=manifest) + result = self.execute_macro(FRESHNESS_MACRO_NAME, kwargs=kwargs, macro_resolver=manifest) if isinstance(result, agate.Table): warn_or_error(CollectFreshnessReturnSignature()) adapter_response = None @@ -1310,7 +1287,7 @@ def calculate_freshness_from_metadata( "relations": [source], } result = self.execute_macro( - GET_RELATION_LAST_MODIFIED_MACRO_NAME, kwargs=kwargs, manifest=manifest + GET_RELATION_LAST_MODIFIED_MACRO_NAME, kwargs=kwargs, macro_resolver=manifest ) adapter_response, table = result.response, result.table # type: ignore[attr-defined] diff --git a/core/dbt/adapters/contracts/macros.py b/core/dbt/adapters/contracts/macros.py new file mode 100644 index 00000000000..5011a337a39 --- /dev/null +++ b/core/dbt/adapters/contracts/macros.py @@ -0,0 +1,11 @@ +from typing import Optional +from typing_extensions import Protocol + +from dbt.common.clients.jinja import MacroProtocol + + +class MacroResolver(Protocol): + def find_macro_by_name( + self, name: str, root_project_name: str, package: Optional[str] + ) -> Optional[MacroProtocol]: + raise NotImplementedError("find_macro_by_name not implemented") diff --git a/core/dbt/adapters/protocol.py b/core/dbt/adapters/protocol.py index b182878ae80..d83b8b582b3 100644 --- a/core/dbt/adapters/protocol.py +++ b/core/dbt/adapters/protocol.py @@ -5,6 +5,7 @@ import agate from dbt.adapters.contracts.connection import Connection, AdapterRequiredConfig, AdapterResponse +from dbt.adapters.contracts.macros import MacroResolver from dbt.adapters.contracts.relation import Policy, HasQuoting, RelationConfig from dbt.contracts.graph.model_config import BaseConfig from dbt.contracts.graph.manifest import Manifest @@ -66,6 +67,15 @@ class AdapterProtocol( # type: ignore[misc] def __init__(self, config: AdapterRequiredConfig) -> None: ... + def set_macro_resolver(self, macro_resolver: MacroResolver) -> None: + ... + + def get_macro_resolver(self) -> Optional[MacroResolver]: + ... + + def clear_macro_resolver(self) -> None: + ... + @classmethod def type(cls) -> str: pass diff --git a/core/dbt/context/providers.py b/core/dbt/context/providers.py index 2ecb5a0b84a..22f8460cd1e 100644 --- a/core/dbt/context/providers.py +++ b/core/dbt/context/providers.py @@ -15,6 +15,7 @@ from typing_extensions import Protocol from dbt.adapters.base.column import Column +from dbt.common.clients.jinja import MacroProtocol from dbt.adapters.factory import get_adapter, get_adapter_package_names, get_adapter_type_names from dbt.common.clients import agate_helper from dbt.clients.jinja import get_rendered, MacroGenerator, MacroStack @@ -1355,7 +1356,7 @@ class MacroContext(ProviderContext): def __init__( self, - model: Macro, + model: MacroProtocol, config: RuntimeConfig, manifest: Manifest, provider: Provider, @@ -1512,7 +1513,7 @@ def generate_runtime_model_context( def generate_runtime_macro_context( - macro: Macro, + macro: MacroProtocol, config: RuntimeConfig, manifest: Manifest, package_name: Optional[str], diff --git a/core/dbt/parser/manifest.py b/core/dbt/parser/manifest.py index e973b5f3592..c952db063f4 100644 --- a/core/dbt/parser/manifest.py +++ b/core/dbt/parser/manifest.py @@ -286,7 +286,7 @@ def get_full_manifest( # the config and adapter may be persistent. if reset: config.clear_dependencies() - adapter.clear_macro_manifest() + adapter.clear_macro_resolver() macro_hook = adapter.connections.set_query_header flags = get_flags() @@ -1000,7 +1000,7 @@ def build_manifest_state_check(self): def save_macros_to_adapter(self, adapter): macro_manifest = MacroManifest(self.manifest.macros) - adapter._macro_manifest_lazy = macro_manifest + adapter.set_macro_resolver(macro_manifest) # This executes the callable macro_hook and sets the # query headers self.macro_hook(macro_manifest) diff --git a/core/dbt/task/run_operation.py b/core/dbt/task/run_operation.py index caa1f1c7b7e..379d5ec6ab8 100644 --- a/core/dbt/task/run_operation.py +++ b/core/dbt/task/run_operation.py @@ -41,7 +41,7 @@ def _run_unsafe(self, package_name, macro_name) -> agate.Table: with adapter.connection_named("macro_{}".format(macro_name)): adapter.clear_transaction() res = adapter.execute_macro( - macro_name, project=package_name, kwargs=macro_kwargs, manifest=self.manifest + macro_name, project=package_name, kwargs=macro_kwargs, macro_resolver=self.manifest ) return res diff --git a/core/dbt/task/show.py b/core/dbt/task/show.py index d6d140898a9..961a36c6127 100644 --- a/core/dbt/task/show.py +++ b/core/dbt/task/show.py @@ -27,7 +27,7 @@ def execute(self, compiled_node, manifest): model_context = generate_runtime_model_context(compiled_node, self.config, manifest) compiled_node.compiled_code = self.adapter.execute_macro( macro_name="get_show_sql", - manifest=manifest, + macro_resolver=manifest, context_override=model_context, kwargs={ "compiled_code": model_context["compiled_code"], diff --git a/core/dbt/tests/fixtures/project.py b/core/dbt/tests/fixtures/project.py index 429207e907d..487450dcb45 100644 --- a/core/dbt/tests/fixtures/project.py +++ b/core/dbt/tests/fixtures/project.py @@ -7,6 +7,7 @@ import warnings import yaml +from dbt.parser.manifest import ManifestLoader from dbt.common.exceptions import CompilationError, DbtDatabaseError import dbt.flags as flags from dbt.config.runtime import RuntimeConfig @@ -289,7 +290,13 @@ def adapter( adapter = get_adapter(runtime_config) # We only need the base macros, not macros from dependencies, and don't want # to run 'dbt deps' here. - adapter.load_macro_manifest(base_macros_only=True) + manifest = ManifestLoader.load_macros( + runtime_config, + adapter.connections.set_query_header, + base_macros_only=True, + ) + + adapter.set_macro_resolver(manifest) yield adapter adapter.cleanup_connections() reset_adapters() @@ -450,6 +457,14 @@ def create_test_schema(self, schema_name=None): # Drop the unique test schema, usually called in test cleanup def drop_test_schema(self): + if self.adapter.get_macro_resolver() is None: + manifest = ManifestLoader.load_macros( + self.adapter.config, + self.adapter.connections.set_query_header, + base_macros_only=True, + ) + self.adapter.set_macro_resolver(manifest) + with get_connection(self.adapter): for schema_name in self.created_schemas: relation = self.adapter.Relation.create(database=self.database, schema=schema_name) diff --git a/tests/unit/test_postgres_adapter.py b/tests/unit/test_postgres_adapter.py index 1dfb7a9146a..d159f227535 100644 --- a/tests/unit/test_postgres_adapter.py +++ b/tests/unit/test_postgres_adapter.py @@ -428,9 +428,9 @@ def _mock_state_check(self): self.psycopg2.connect.return_value = self.handle self.adapter = PostgresAdapter(self.config, self.mp_context) - self.adapter._macro_manifest_lazy = load_internal_manifest_macros(self.config) + self.adapter.set_macro_resolver(load_internal_manifest_macros(self.config)) self.adapter.connections.query_header = MacroQueryStringSetter( - self.config, self.adapter._macro_manifest_lazy + self.config, self.adapter.get_macro_resolver() ) self.qh_patch = mock.patch.object(self.adapter.connections.query_header, "add")