Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Backport 9328 to 1.7.latest #9391

Merged
merged 3 commits into from
Jan 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .changes/unreleased/Fixes-20231213-220449.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Fixes
body: Preserve the value of vars and the --full-refresh flags when using retry.
time: 2023-12-13T22:04:49.228294-05:00
custom:
Author: peterallenwebb, ChenyuLInx
Issue: "9112"
8 changes: 4 additions & 4 deletions core/dbt/cli/flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
FLAGS_DEFAULTS = {
"INDIRECT_SELECTION": "eager",
"TARGET_PATH": None,
"WARN_ERROR": None,
# Cli args without user_config or env var option.
"FULL_REFRESH": False,
"STRICT_MODE": False,
Expand Down Expand Up @@ -78,7 +79,6 @@ class Flags:
def __init__(
self, ctx: Optional[Context] = None, user_config: Optional[UserConfig] = None
) -> None:

# Set the default flags.
for key, value in FLAGS_DEFAULTS.items():
object.__setattr__(self, key, value)
Expand Down Expand Up @@ -120,7 +120,6 @@ def _assign_params(
# respected over DBT_PRINT or --print.
new_name: Union[str, None] = None
if param_name in DEPRECATED_PARAMS:

# Deprecated env vars can only be set via env var.
# We use the deprecated option in click to serialize the value
# from the env var string.
Expand Down Expand Up @@ -315,7 +314,6 @@ def command_params(command: CliCommand, args_dict: Dict[str, Any]) -> CommandPar
default_args = set([x.lower() for x in FLAGS_DEFAULTS.keys()])

res = command.to_list()

for k, v in args_dict.items():
k = k.lower()
# if a "which" value exists in the args dict, it should match the command provided
Expand All @@ -327,7 +325,9 @@ def command_params(command: CliCommand, args_dict: Dict[str, Any]) -> CommandPar
continue

# param was assigned from defaults and should not be included
if k not in (cmd_args | prnt_args) - default_args:
if k not in (cmd_args | prnt_args) or (
k in default_args and v == FLAGS_DEFAULTS[k.upper()]
):
continue

# if the param is in parent args, it should come before the arg name
Expand Down
2 changes: 1 addition & 1 deletion core/dbt/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -638,12 +638,12 @@ def run(ctx, **kwargs):
@p.target
@p.state
@p.threads
@p.full_refresh
@requires.postflight
@requires.preflight
@requires.profile
@requires.project
@requires.runtime_config
@requires.manifest
def retry(ctx, **kwargs):
"""Retry the nodes that failed in the previous run."""
task = RetryTask(
Expand Down
22 changes: 7 additions & 15 deletions core/dbt/cli/requires.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,12 @@
from dbt.events.helpers import get_json_string_utcnow
from dbt.events.types import MainEncounteredError, MainStackTrace
from dbt.exceptions import Exception as DbtException, DbtProjectError, FailFastError
from dbt.parser.manifest import ManifestLoader, write_manifest
from dbt.parser.manifest import parse_manifest
from dbt.profiler import profiler
from dbt.tracking import active_user, initialize_from_flags, track_run
from dbt.utils import cast_dict_to_dict_of_strings
from dbt.plugins import set_up_plugin_manager, get_plugin_manager
from dbt.plugins import set_up_plugin_manager


from click import Context
from functools import update_wrapper
Expand Down Expand Up @@ -264,23 +265,14 @@ def wrapper(*args, **kwargs):
raise DbtProjectError("profile, project, and runtime_config required for manifest")

runtime_config = ctx.obj["runtime_config"]
register_adapter(runtime_config)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Relocating the call to register_adapter() makes me nervous. Before this change, it looks like register_adapter() was called even when ctx.obj.get("manifest") was already set. But now, it will not be called in that scenario. Do you know if that is safe?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is still being called, on line 271 of this file.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, but it is only called if the if condition evaluates to True. What about the case where it is False?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point! Updated to restore the previous behavior.

# a manifest has already been set on the context, so don't overwrite it
if ctx.obj.get("manifest") is None:
manifest = ManifestLoader.get_full_manifest(
runtime_config,
write_perf_info=write_perf_info,
ctx.obj["manifest"] = parse_manifest(
runtime_config, write_perf_info, write, ctx.obj["flags"].write_json
)

ctx.obj["manifest"] = manifest
if write and ctx.obj["flags"].write_json:
write_manifest(manifest, runtime_config.project_target_path)
pm = get_plugin_manager(runtime_config.project_name)
plugin_artifacts = pm.get_manifest_artifacts(manifest)
for path, plugin_artifact in plugin_artifacts.items():
plugin_artifact.write(path)

else:
register_adapter(runtime_config)
return func(*args, **kwargs)

return update_wrapper(wrapper, func)
Expand Down
17 changes: 11 additions & 6 deletions core/dbt/contracts/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,16 @@
from dbt.exceptions import IncompatibleSchemaError


def load_result_state(results_path) -> Optional[RunResultsArtifact]:
if results_path.exists() and results_path.is_file():
try:
return RunResultsArtifact.read_and_check_versions(str(results_path))
except IncompatibleSchemaError as exc:
exc.add_filename(str(results_path))
raise

Check warning on line 18 in core/dbt/contracts/state.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/contracts/state.py#L16-L18

Added lines #L16 - L18 were not covered by tests
return None


class PreviousState:
def __init__(self, state_path: Path, target_path: Path, project_root: Path) -> None:
self.state_path: Path = state_path
Expand All @@ -32,12 +42,7 @@
raise

results_path = self.project_root / self.state_path / "run_results.json"
if results_path.exists() and results_path.is_file():
try:
self.results = RunResultsArtifact.read_and_check_versions(str(results_path))
except IncompatibleSchemaError as exc:
exc.add_filename(str(results_path))
raise
self.results = load_result_state(results_path)

sources_path = self.project_root / self.state_path / "sources.json"
if sources_path.exists() and sources_path.is_file():
Expand Down
21 changes: 17 additions & 4 deletions core/dbt/parser/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
get_adapter,
get_relation_class_by_name,
get_adapter_package_names,
register_adapter,
)
from dbt.constants import (
MANIFEST_FILE_NAME,
Expand Down Expand Up @@ -278,7 +279,6 @@
reset: bool = False,
write_perf_info=False,
) -> Manifest:

adapter = get_adapter(config) # type: ignore
# reset is set in a TaskManager load_manifest call, since
# the config and adapter may be persistent.
Expand Down Expand Up @@ -590,7 +590,6 @@
node.depends_on
for resolved_ref in resolved_model_refs:
if resolved_ref.deprecation_date:

if resolved_ref.deprecation_date < datetime.datetime.now().astimezone():
event_cls = DeprecatedReference
else:
Expand Down Expand Up @@ -1733,7 +1732,6 @@


def _process_sources_for_node(manifest: Manifest, current_project: str, node: ManifestNode):

if isinstance(node, SeedNode):
return

Expand Down Expand Up @@ -1775,7 +1773,6 @@
# This is called in task.rpc.sql_commands when a "dynamic" node is
# created in the manifest, in 'add_refs'
def process_node(config: RuntimeConfig, manifest: Manifest, node: ManifestNode):

_process_sources_for_node(manifest, config.project_name, node)
_process_refs(manifest, config.project_name, node, config.dependencies)
ctx = generate_runtime_docs_context(config, node, manifest, config.project_name)
Expand All @@ -1793,3 +1790,19 @@
manifest.write(path)

write_semantic_manifest(manifest=manifest, target_path=target_path)


def parse_manifest(runtime_config, write_perf_info, write, write_json):
register_adapter(runtime_config)
manifest = ManifestLoader.get_full_manifest(
runtime_config,
write_perf_info=write_perf_info,
)

if write and write_json:
write_manifest(manifest, runtime_config.project_target_path)
pm = plugins.get_plugin_manager(runtime_config.project_name)
plugin_artifacts = pm.get_manifest_artifacts(manifest)
for path, plugin_artifact in plugin_artifacts.items():
plugin_artifact.write(path)

Check warning on line 1807 in core/dbt/parser/manifest.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/parser/manifest.py#L1807

Added line #L1807 was not covered by tests
return manifest
84 changes: 48 additions & 36 deletions core/dbt/task/retry.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from pathlib import Path
from click import get_current_context
from click.core import ParameterSource

from dbt.cli.flags import Flags
from dbt.flags import set_flags, get_flags
from dbt.cli.types import Command as CliCommand
from dbt.config import RuntimeConfig
from dbt.contracts.results import NodeStatus
from dbt.contracts.state import PreviousState
from dbt.contracts.state import load_result_state
from dbt.exceptions import DbtRuntimeError
from dbt.graph import GraphQueue
from dbt.task.base import ConfiguredTask
Expand All @@ -17,9 +20,10 @@
from dbt.task.seed import SeedTask
from dbt.task.snapshot import SnapshotTask
from dbt.task.test import TestTask
from dbt.parser.manifest import parse_manifest

RETRYABLE_STATUSES = {NodeStatus.Error, NodeStatus.Fail, NodeStatus.Skipped, NodeStatus.RuntimeErr}
OVERRIDE_PARENT_FLAGS = {
IGNORE_PARENT_FLAGS = {
"log_path",
"output_path",
"profiles_dir",
Expand All @@ -28,8 +32,11 @@
"defer_state",
"deprecated_state",
"target_path",
"warn_error",
}

ALLOW_CLI_OVERRIDE_FLAGS = {"vars"}

TASK_DICT = {
"build": BuildTask,
"compile": CompileTask,
Expand Down Expand Up @@ -57,59 +64,64 @@

class RetryTask(ConfiguredTask):
def __init__(self, args, config, manifest) -> None:
super().__init__(args, config, manifest)

state_path = self.args.state or self.config.target_path

if self.args.warn_error:
RETRYABLE_STATUSES.add(NodeStatus.Warn)

self.previous_state = PreviousState(
state_path=Path(state_path),
target_path=Path(self.config.target_path),
project_root=Path(self.config.project_root),
# load previous run results
state_path = args.state or config.target_path
self.previous_results = load_result_state(
Path(config.project_root) / Path(state_path) / "run_results.json"
)

if not self.previous_state.results:
if not self.previous_results:
raise DbtRuntimeError(
f"Could not find previous run in '{state_path}' target directory"
)

self.previous_args = self.previous_state.results.args
self.previous_args = self.previous_results.args
self.previous_command_name = self.previous_args.get("which")
self.task_class = TASK_DICT.get(self.previous_command_name) # type: ignore

def run(self):
unique_ids = set(
[
result.unique_id
for result in self.previous_state.results.results
if result.status in RETRYABLE_STATUSES
]
)

cli_command = CMD_DICT.get(self.previous_command_name)
# Resolve flags and config
if args.warn_error:
RETRYABLE_STATUSES.add(NodeStatus.Warn)

cli_command = CMD_DICT.get(self.previous_command_name) # type: ignore
# Remove these args when their default values are present, otherwise they'll raise an exception
args_to_remove = {
"show": lambda x: True,
"resource_types": lambda x: x == [],
"warn_error_options": lambda x: x == {"exclude": [], "include": []},
}

for k, v in args_to_remove.items():
if k in self.previous_args and v(self.previous_args[k]):
del self.previous_args[k]

previous_args = {
k: v for k, v in self.previous_args.items() if k not in OVERRIDE_PARENT_FLAGS
k: v for k, v in self.previous_args.items() if k not in IGNORE_PARENT_FLAGS
}
click_context = get_current_context()
current_args = {
k: v
for k, v in args.__dict__.items()
if k in IGNORE_PARENT_FLAGS
or (
click_context.get_parameter_source(k) == ParameterSource.COMMANDLINE
and k in ALLOW_CLI_OVERRIDE_FLAGS
)
}
current_args = {k: v for k, v in self.args.__dict__.items() if k in OVERRIDE_PARENT_FLAGS}
combined_args = {**previous_args, **current_args}

retry_flags = Flags.from_dict(cli_command, combined_args)
retry_flags = Flags.from_dict(cli_command, combined_args) # type: ignore
set_flags(retry_flags)
retry_config = RuntimeConfig.from_args(args=retry_flags)

# Parse manifest using resolved config/flags
manifest = parse_manifest(retry_config, False, True, retry_flags.write_json) # type: ignore
super().__init__(args, retry_config, manifest)
self.task_class = TASK_DICT.get(self.previous_command_name) # type: ignore

def run(self):
unique_ids = set(
[
result.unique_id
for result in self.previous_results.results
if result.status in RETRYABLE_STATUSES
]
)

class TaskWrapper(self.task_class):
def get_graph_queue(self):
new_graph = self.graph.get_subset_graph(unique_ids)
Expand All @@ -120,8 +132,8 @@ def get_graph_queue(self):
)

task = TaskWrapper(
retry_flags,
retry_config,
get_flags(),
self.config,
self.manifest,
)

Expand Down
38 changes: 37 additions & 1 deletion tests/functional/retry/test_retry.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,10 @@ def test_previous_run(self, project):
write_file(models__sample_model, "models", "sample_model.sql")

def test_warn_error(self, project):
# Regular build
# Our test command should succeed when run normally...
results = run_dbt(["build", "--select", "second_model"])

# ...but it should fail when run with warn-error, due to a warning...
results = run_dbt(["--warn-error", "build", "--select", "second_model"], expect_pass=False)

expected_statuses = {
Expand Down Expand Up @@ -291,3 +294,36 @@ def test_retry(self, project):
run_dbt(["run", "--project-dir", "proj_location_1"], expect_pass=False)
move(proj_location_1, proj_location_2)
run_dbt(["retry", "--project-dir", "proj_location_2"], expect_pass=False)


class TestRetryVars:
    """Functional test: `retry` reuses the previous run's --vars unless new ones are given."""

    @pytest.fixture(scope="class")
    def models(self):
        # Single model whose SQL is built from two vars, each with a default.
        sample_sql = "select {{ var('myvar_a', '1') + var('myvar_b', '2') }} as mycol"
        return {"sample_model.sql": sample_sql}

    def test_retry(self, project):
        failing_vars = '{"myvar_a": "12", "myvar_b": "3 4"}'
        passing_vars = '{"myvar_a": "12", "myvar_b": "34"}'

        # With no --vars the model falls back to its defaults and the run passes.
        run_dbt(["run"])

        # This invocation is expected to fail because the supplied vars are invalid.
        run_dbt(["run", "--vars", failing_vars], expect_pass=False)

        # A bare retry is expected to fail too: that shows the vars from the
        # last invocation are reused rather than the model defaults.
        run_dbt(["retry"], expect_pass=False)

        # Vars passed explicitly to retry take effect; exactly one result comes back.
        retried = run_dbt(["retry", "--vars", passing_vars])
        assert len(retried) == 1


class TestRetryFullRefresh:
    """Functional test: the --full-refresh flag of the previous run persists into `retry`."""

    @pytest.fixture(scope="class")
    def models(self):
        # The model only renders invalid SQL when flags.FULL_REFRESH is truthy.
        model_sql = "{% if flags.FULL_REFRESH %} this is invalid sql {% else %} select 1 as mycol {% endif %}"
        return {"sample_model.sql": model_sql}

    def test_retry(self, project):
        # The initial full-refresh run is expected to fail on the invalid SQL branch.
        run_dbt(["run", "--full-refresh"], expect_pass=False)

        # retry is given no flag of its own; it is expected to fail the same way,
        # since the effect of the earlier --full-refresh should carry over.
        retried = run_dbt(["retry"], expect_pass=False)
        assert len(retried) == 1