Skip to content

Commit

Permalink
Integration Test Optimizations (#9499)
Browse files Browse the repository at this point in the history
* Cache static objects between tests to accelerate integration runs.

* Temporarily pinning to dbt-common branch for testing.

* Add changelog entry.

* Re-pin dbt-common
  • Loading branch information
peterallenwebb authored Feb 1, 2024
1 parent 1a5d692 commit db65e62
Show file tree
Hide file tree
Showing 7 changed files with 78 additions and 29 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Features-20240131-153535.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Changie changelog entry for the "Integration Test Optimizations" change
# (issue 9498); rendered under the "Features" section of the next release notes.
kind: Features
body: Integration Test Optimizations
time: 2024-01-31T15:35:35.691224-05:00
custom:
  Author: peterallenwebb
  Issue: "9498"
6 changes: 6 additions & 0 deletions .changes/unreleased/Features-20240201-154956.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Changie changelog entry for the test-caching acceleration work
# (issue 9498); rendered under the "Features" section of the next release notes.
kind: Features
body: Accelerate integration tests with caching.
time: 2024-02-01T15:49:56.422651-05:00
custom:
  Author: peterallenwebb
  Issue: "9498"
36 changes: 19 additions & 17 deletions core/dbt/clients/jinja_static.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,38 @@
from typing import Any, Dict, Optional

import jinja2
from dbt_common.clients.jinja import get_environment
from dbt_common.tests import test_caching_enabled
from dbt.exceptions import MacroNamespaceNotStringError
from dbt_common.exceptions.macros import MacroNameNotStringError


_TESTING_MACRO_CACHE: Optional[Dict[str, Any]] = {}


def statically_extract_macro_calls(string, ctx, db_wrapper=None):
# set 'capture_macros' to capture undefined
env = get_environment(None, capture_macros=True)
parsed = env.parse(string)

global _TESTING_MACRO_CACHE
if test_caching_enabled() and string in _TESTING_MACRO_CACHE:
parsed = _TESTING_MACRO_CACHE.get(string, None)
func_calls = getattr(parsed, "_dbt_cached_calls")
else:
parsed = env.parse(string)
func_calls = tuple(parsed.find_all(jinja2.nodes.Call))

if test_caching_enabled():
_TESTING_MACRO_CACHE[string] = parsed
setattr(parsed, "_dbt_cached_calls", func_calls)

standard_calls = ["source", "ref", "config"]
possible_macro_calls = []
for func_call in parsed.find_all(jinja2.nodes.Call):
for func_call in func_calls:
func_name = None
if hasattr(func_call, "node") and hasattr(func_call.node, "name"):
func_name = func_call.node.name
else:
# func_call for dbt.current_timestamp macro
# Call(
# node=Getattr(
# node=Name(
# name='dbt_utils',
# ctx='load'
# ),
# attr='current_timestamp',
# ctx='load
# ),
# args=[],
# kwargs=[],
# dyn_args=None,
# dyn_kwargs=None
# )
if (
hasattr(func_call, "node")
and hasattr(func_call.node, "node")
Expand Down
32 changes: 22 additions & 10 deletions core/dbt/parser/macros.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,16 +66,28 @@ def parse_unparsed_macros(self, base_node: UnparsedMacro) -> Iterable[Macro]:
e.add_node(base_node)
raise

macro_nodes = list(ast.find_all(jinja2.nodes.Macro))

if len(macro_nodes) != 1:
# things have gone disastrously wrong, we thought we only
# parsed one block!
raise ParsingError(
f"Found multiple macros in {block.full_block}, expected 1", node=base_node
)

macro = macro_nodes[0]
if (
isinstance(ast, jinja2.nodes.Template)
and hasattr(ast, "body")
and len(ast.body) == 1
and isinstance(ast.body[0], jinja2.nodes.Macro)
):
# If the top level node in the Template is a Macro, things look
# good and this is much faster than traversing the full ast, as
# in the following else clause. It's not clear if that traversal
# is ever really needed.
macro = ast.body[0]
else:
macro_nodes = list(ast.find_all(jinja2.nodes.Macro))

if len(macro_nodes) != 1:
# things have gone disastrously wrong, we thought we only
# parsed one block!
raise ParsingError(
f"Found multiple macros in {block.full_block}, expected 1", node=base_node
)

macro = macro_nodes[0]

if not macro.name.startswith(MACRO_PREFIX):
continue
Expand Down
23 changes: 22 additions & 1 deletion core/dbt/plugins/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Dict, List, Callable, Mapping

from dbt.contracts.graph.manifest import Manifest
from dbt_common.tests import test_caching_enabled
from dbt_common.exceptions import DbtRuntimeError
from dbt.plugins.contracts import PluginArtifacts
from dbt.plugins.manifest import PluginNodes
Expand Down Expand Up @@ -76,6 +77,9 @@ def _get_dbt_modules() -> Mapping[str, ModuleType]:
}


_MODULES_CACHE = None


class PluginManager:
PLUGIN_MODULE_PREFIX = "dbt_"
PLUGIN_ATTR_NAME = "plugins"
Expand Down Expand Up @@ -104,7 +108,16 @@ def __init__(self, plugins: List[dbtPlugin]) -> None:

@classmethod
def from_modules(cls, project_name: str) -> "PluginManager":
discovered_dbt_modules = _get_dbt_modules()

if test_caching_enabled():
global _MODULES_CACHE
if _MODULES_CACHE is None:
discovered_dbt_modules = cls.get_prefixed_modules()
_MODULES_CACHE = discovered_dbt_modules
else:
discovered_dbt_modules = _MODULES_CACHE
else:
discovered_dbt_modules = cls.get_prefixed_modules()

plugins = []
for name, module in discovered_dbt_modules.items():
Expand All @@ -118,6 +131,14 @@ def from_modules(cls, project_name: str) -> "PluginManager":
plugins.append(plugin)
return cls(plugins=plugins)

@classmethod
def get_prefixed_modules(cls):
    """Discover importable modules whose names carry the dbt plugin prefix.

    Scans all modules visible to ``pkgutil.iter_modules()``, imports each
    one whose name starts with ``cls.PLUGIN_MODULE_PREFIX``, and returns a
    mapping of module name -> imported module object.
    """
    prefixed_modules = {}
    for _, module_name, _ in pkgutil.iter_modules():
        if module_name.startswith(cls.PLUGIN_MODULE_PREFIX):
            prefixed_modules[module_name] = importlib.import_module(module_name)
    return prefixed_modules

def get_manifest_artifacts(self, manifest: Manifest) -> PluginArtifacts:
all_plugin_artifacts = {}
for hook_method in self.hooks.get("get_manifest_artifacts", []):
Expand Down
2 changes: 2 additions & 0 deletions core/dbt/tests/fixtures/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from dbt_common.exceptions import CompilationError, DbtDatabaseError
from dbt.context.providers import generate_runtime_macro_context
import dbt.flags as flags
from dbt_common.tests import enable_test_caching
from dbt.config.runtime import RuntimeConfig
from dbt.adapters.factory import get_adapter, register_adapter, reset_adapters, get_adapter_by_type
from dbt_common.events.event_manager_client import cleanup_event_logger
Expand Down Expand Up @@ -517,6 +518,7 @@ def project(
# Logbook warnings are ignored so we don't have to fork logbook to support python 3.10.
# This _only_ works for tests in `tests/` that use the project fixture.
warnings.filterwarnings("ignore", category=DeprecationWarning, module="logbook")
enable_test_caching()
log_flags = Namespace(
LOG_PATH=logs_dir,
LOG_FORMAT="json",
Expand Down
2 changes: 1 addition & 1 deletion core/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@
"dbt-extractor~=0.5.0",
"minimal-snowplow-tracker~=0.0.2",
"dbt-semantic-interfaces~=0.5.0a2",
"dbt-common~=0.1.0",
"dbt-common~=0.1.3",
"dbt-adapters~=0.1.0a2",
# ----
# Expect compatibility with all new versions of these packages, so lower bounds only.
Expand Down

0 comments on commit db65e62

Please sign in to comment.