diff --git a/core/dbt/adapters/base/impl.py b/core/dbt/adapters/base/impl.py index dffd4f65f2f..9d06dcdc674 100644 --- a/core/dbt/adapters/base/impl.py +++ b/core/dbt/adapters/base/impl.py @@ -9,7 +9,6 @@ Any, Callable, Dict, - Iterable, Iterator, List, Mapping, @@ -28,6 +27,7 @@ ConstraintType, ModelLevelConstraint, ) +from dbt.adapters.contracts.macros import MacroResolver import agate import pytz @@ -62,7 +62,7 @@ Integer, ) from dbt.common.clients.jinja import CallableMacroGenerator -from dbt.contracts.graph.manifest import Manifest, MacroManifest +from dbt.contracts.graph.manifest import Manifest from dbt.common.events.functions import fire_event, warn_or_error from dbt.adapters.events.types import ( CacheMiss, @@ -254,7 +254,20 @@ def __init__(self, config, mp_context: SpawnContext) -> None: self.config = config self.cache = RelationsCache(log_cache_events=config.log_cache_events) self.connections = self.ConnectionManager(config, mp_context) - self._macro_manifest_lazy: Optional[MacroManifest] = None + self._macro_resolver: Optional[MacroResolver] = None + + ### + # Methods to set / access a macro resolver + ### + def set_macro_resolver(self, macro_resolver: MacroResolver) -> None: + self._macro_resolver = macro_resolver + + def get_macro_resolver(self) -> Optional[MacroResolver]: + return self._macro_resolver + + def clear_macro_resolver(self) -> None: + if self._macro_resolver is not None: + self._macro_resolver = None ### # Methods that pass through to the connection manager @@ -367,39 +380,6 @@ def type(cls) -> str: """ return cls.ConnectionManager.TYPE - @property - def _macro_manifest(self) -> MacroManifest: - if self._macro_manifest_lazy is None: - return self.load_macro_manifest() - return self._macro_manifest_lazy - - def check_macro_manifest(self) -> Optional[MacroManifest]: - """Return the internal manifest (used for executing macros) if it's - been initialized, otherwise return None. - """ - return self._macro_manifest_lazy - - def load_macro_manifest(self, base_macros_only=False) -> MacroManifest: - # base_macros_only is for the test framework - if self._macro_manifest_lazy is None: - # avoid a circular import - from dbt.parser.manifest import ManifestLoader - - manifest = ManifestLoader.load_macros( - self.config, - self.connections.set_query_header, - base_macros_only=base_macros_only, - ) - # TODO CT-211 - self._macro_manifest_lazy = manifest # type: ignore[assignment] - # TODO CT-211 - return self._macro_manifest_lazy # type: ignore[return-value] - - def clear_macro_manifest(self): - if self._macro_manifest_lazy is not None: - self._macro_manifest_lazy = None - - ### # Caching methods ### def _schema_is_cached(self, database: Optional[str], schema: str) -> bool: @@ -1054,11 +1034,10 @@ def convert_agate_type(cls, agate_table: agate.Table, col_idx: int) -> Optional[ def execute_macro( self, macro_name: str, - manifest: Optional[Manifest] = None, + macro_resolver: Optional[MacroResolver] = None, project: Optional[str] = None, context_override: Optional[Dict[str, Any]] = None, kwargs: Optional[Dict[str, Any]] = None, - text_only_columns: Optional[Iterable[str]] = None, ) -> AttrDict: """Look macro_name up in the manifest and execute its results. @@ -1078,13 +1057,11 @@ def execute_macro( if context_override is None: context_override = {} - if manifest is None: - # TODO CT-211 - manifest = self._macro_manifest # type: ignore[assignment] - # TODO CT-211 - macro = manifest.find_macro_by_name( # type: ignore[union-attr] - macro_name, self.config.project_name, project - ) + resolver = macro_resolver or self._macro_resolver + if resolver is None: + raise DbtInternalError("macro resolver was None when calling execute_macro!") + + macro = resolver.find_macro_by_name(macro_name, self.config.project_name, project) if macro is None: if project is None: package_name = "any package" @@ -1104,7 +1081,7 @@ def execute_macro( # TODO CT-211 macro=macro, config=self.config, - manifest=manifest, # type: ignore[arg-type] + manifest=resolver, # type: ignore[arg-type] package_name=project, ) macro_context.update(context_override) @@ -1140,7 +1117,7 @@ def _get_one_catalog( kwargs=kwargs, # pass in the full manifest so we get any local project # overrides - manifest=manifest, + macro_resolver=manifest, ) results = self._catalog_filter_table(table, manifest) # type: ignore[arg-type] @@ -1162,7 +1139,7 @@ def _get_one_catalog_by_relations( kwargs=kwargs, # pass in the full manifest, so we get any local project # overrides - manifest=manifest, + macro_resolver=manifest, ) results = self._catalog_filter_table(table, manifest) # type: ignore[arg-type] @@ -1273,7 +1250,7 @@ def calculate_freshness( AttrDict, # current: contains AdapterResponse + agate.Table agate.Table, # previous: just table ] - result = self.execute_macro(FRESHNESS_MACRO_NAME, kwargs=kwargs, manifest=manifest) + result = self.execute_macro(FRESHNESS_MACRO_NAME, kwargs=kwargs, macro_resolver=manifest) if isinstance(result, agate.Table): warn_or_error(CollectFreshnessReturnSignature()) adapter_response = None @@ -1310,7 +1287,7 @@ def calculate_freshness_from_metadata( "relations": [source], } result = self.execute_macro( - GET_RELATION_LAST_MODIFIED_MACRO_NAME, kwargs=kwargs, manifest=manifest + GET_RELATION_LAST_MODIFIED_MACRO_NAME, kwargs=kwargs, macro_resolver=manifest ) adapter_response, table = result.response, result.table # type: ignore[attr-defined] diff --git a/core/dbt/adapters/contracts/macros.py b/core/dbt/adapters/contracts/macros.py new file mode 100644 index 00000000000..5011a337a39 --- /dev/null +++ b/core/dbt/adapters/contracts/macros.py @@ -0,0 +1,11 @@ +from typing import Optional +from typing_extensions import Protocol + +from dbt.common.clients.jinja import MacroProtocol + + +class MacroResolver(Protocol): + def find_macro_by_name( + self, name: str, root_project_name: str, package: Optional[str] + ) -> Optional[MacroProtocol]: + raise NotImplementedError("find_macro_by_name not implemented") diff --git a/core/dbt/adapters/protocol.py b/core/dbt/adapters/protocol.py index b182878ae80..d83b8b582b3 100644 --- a/core/dbt/adapters/protocol.py +++ b/core/dbt/adapters/protocol.py @@ -5,6 +5,7 @@ import agate from dbt.adapters.contracts.connection import Connection, AdapterRequiredConfig, AdapterResponse +from dbt.adapters.contracts.macros import MacroResolver from dbt.adapters.contracts.relation import Policy, HasQuoting, RelationConfig from dbt.contracts.graph.model_config import BaseConfig from dbt.contracts.graph.manifest import Manifest @@ -66,6 +67,15 @@ class AdapterProtocol( # type: ignore[misc] def __init__(self, config: AdapterRequiredConfig) -> None: ... + def set_macro_resolver(self, macro_resolver: MacroResolver) -> None: + ... + + def get_macro_resolver(self) -> Optional[MacroResolver]: + ... + + def clear_macro_resolver(self) -> None: + ... + @classmethod def type(cls) -> str: pass diff --git a/core/dbt/context/providers.py b/core/dbt/context/providers.py index 2ecb5a0b84a..22f8460cd1e 100644 --- a/core/dbt/context/providers.py +++ b/core/dbt/context/providers.py @@ -15,6 +15,7 @@ from typing_extensions import Protocol from dbt.adapters.base.column import Column +from dbt.common.clients.jinja import MacroProtocol from dbt.adapters.factory import get_adapter, get_adapter_package_names, get_adapter_type_names from dbt.common.clients import agate_helper from dbt.clients.jinja import get_rendered, MacroGenerator, MacroStack @@ -1355,7 +1356,7 @@ class MacroContext(ProviderContext): def __init__( self, - model: Macro, + model: MacroProtocol, config: RuntimeConfig, manifest: Manifest, provider: Provider, @@ -1512,7 +1513,7 @@ def generate_runtime_model_context( def generate_runtime_macro_context( - macro: Macro, + macro: MacroProtocol, config: RuntimeConfig, manifest: Manifest, package_name: Optional[str], diff --git a/core/dbt/parser/manifest.py b/core/dbt/parser/manifest.py index e973b5f3592..c952db063f4 100644 --- a/core/dbt/parser/manifest.py +++ b/core/dbt/parser/manifest.py @@ -286,7 +286,7 @@ def get_full_manifest( # the config and adapter may be persistent. if reset: config.clear_dependencies() - adapter.clear_macro_manifest() + adapter.clear_macro_resolver() macro_hook = adapter.connections.set_query_header flags = get_flags() @@ -1000,7 +1000,7 @@ def build_manifest_state_check(self): def save_macros_to_adapter(self, adapter): macro_manifest = MacroManifest(self.manifest.macros) - adapter._macro_manifest_lazy = macro_manifest + adapter.set_macro_resolver(macro_manifest) # This executes the callable macro_hook and sets the # query headers self.macro_hook(macro_manifest) diff --git a/core/dbt/task/run_operation.py b/core/dbt/task/run_operation.py index caa1f1c7b7e..379d5ec6ab8 100644 --- a/core/dbt/task/run_operation.py +++ b/core/dbt/task/run_operation.py @@ -41,7 +41,7 @@ def _run_unsafe(self, package_name, macro_name) -> agate.Table: with adapter.connection_named("macro_{}".format(macro_name)): adapter.clear_transaction() res = adapter.execute_macro( - macro_name, project=package_name, kwargs=macro_kwargs, manifest=self.manifest + macro_name, project=package_name, kwargs=macro_kwargs, macro_resolver=self.manifest ) return res diff --git a/core/dbt/task/show.py b/core/dbt/task/show.py index d6d140898a9..961a36c6127 100644 --- a/core/dbt/task/show.py +++ b/core/dbt/task/show.py @@ -27,7 +27,7 @@ def execute(self, compiled_node, manifest): model_context = generate_runtime_model_context(compiled_node, self.config, manifest) compiled_node.compiled_code = self.adapter.execute_macro( macro_name="get_show_sql", - manifest=manifest, + macro_resolver=manifest, context_override=model_context, kwargs={ "compiled_code": model_context["compiled_code"], diff --git a/core/dbt/tests/fixtures/project.py b/core/dbt/tests/fixtures/project.py index 429207e907d..487450dcb45 100644 --- a/core/dbt/tests/fixtures/project.py +++ b/core/dbt/tests/fixtures/project.py @@ -7,6 +7,7 @@ import warnings import yaml +from dbt.parser.manifest import ManifestLoader from dbt.common.exceptions import CompilationError, DbtDatabaseError import dbt.flags as flags from dbt.config.runtime import RuntimeConfig @@ -289,7 +290,13 @@ def adapter( adapter = get_adapter(runtime_config) # We only need the base macros, not macros from dependencies, and don't want # to run 'dbt deps' here. - adapter.load_macro_manifest(base_macros_only=True) + manifest = ManifestLoader.load_macros( + runtime_config, + adapter.connections.set_query_header, + base_macros_only=True, + ) + + adapter.set_macro_resolver(manifest) yield adapter adapter.cleanup_connections() reset_adapters() @@ -450,6 +457,14 @@ def create_test_schema(self, schema_name=None): # Drop the unique test schema, usually called in test cleanup def drop_test_schema(self): + if self.adapter.get_macro_resolver() is None: + manifest = ManifestLoader.load_macros( + self.adapter.config, + self.adapter.connections.set_query_header, + base_macros_only=True, + ) + self.adapter.set_macro_resolver(manifest) + with get_connection(self.adapter): for schema_name in self.created_schemas: relation = self.adapter.Relation.create(database=self.database, schema=schema_name) diff --git a/tests/unit/test_postgres_adapter.py b/tests/unit/test_postgres_adapter.py index 1dfb7a9146a..d159f227535 100644 --- a/tests/unit/test_postgres_adapter.py +++ b/tests/unit/test_postgres_adapter.py @@ -428,9 +428,9 @@ def _mock_state_check(self): self.psycopg2.connect.return_value = self.handle self.adapter = PostgresAdapter(self.config, self.mp_context) - self.adapter._macro_manifest_lazy = load_internal_manifest_macros(self.config) + self.adapter.set_macro_resolver(load_internal_manifest_macros(self.config)) self.adapter.connections.query_header = MacroQueryStringSetter( - self.config, self.adapter._macro_manifest_lazy + self.config, self.adapter.get_macro_resolver() ) self.qh_patch = mock.patch.object(self.adapter.connections.query_header, "add")