Fix full-refresh and vars for retry #9328

Merged · 15 commits · Jan 10, 2024
6 changes: 6 additions & 0 deletions .changes/unreleased/Fixes-20231213-220449.yaml
@@ -0,0 +1,6 @@
kind: Fixes
body: Preserve the value of vars and the --full-refresh flag when using retry.
time: 2023-12-13T22:04:49.228294-05:00
custom:
Author: peterallenwebb, ChenyuLInx
Issue: "9112"
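
For context, here is roughly what the fixed behavior looks like from the programmatic entry point (a sketch; the vars values are hypothetical, but dbtRunner.invoke is dbt's documented Python API):

from dbt.cli.main import dbtRunner

runner = dbtRunner()

# First invocation: pass vars and --full-refresh (example values only).
runner.invoke(["run", "--full-refresh", "--vars", '{"start_date": "2024-01-01"}'])

# With this fix, retry re-applies both the previous vars and --full-refresh
# instead of silently dropping them; passing --vars here would override them.
runner.invoke(["retry"])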
8 changes: 4 additions & 4 deletions core/dbt/cli/flags.py
@@ -27,6 +27,7 @@
FLAGS_DEFAULTS = {
"INDIRECT_SELECTION": "eager",
"TARGET_PATH": None,
"WARN_ERROR": None,
# Cli args without project_flags or env var option.
"FULL_REFRESH": False,
"STRICT_MODE": False,
@@ -81,7 +82,6 @@ class Flags:
def __init__(
self, ctx: Optional[Context] = None, project_flags: Optional[ProjectFlags] = None
) -> None:

# Set the default flags.
for key, value in FLAGS_DEFAULTS.items():
object.__setattr__(self, key, value)
@@ -123,7 +123,6 @@ def _assign_params(
# respected over DBT_PRINT or --print.
new_name: Union[str, None] = None
if param_name in DEPRECATED_PARAMS:

# Deprecated env vars can only be set via env var.
# We use the deprecated option in click to serialize the value
# from the env var string.
@@ -343,7 +342,6 @@ def command_params(command: CliCommand, args_dict: Dict[str, Any]) -> CommandPar
default_args = set([x.lower() for x in FLAGS_DEFAULTS.keys()])

res = command.to_list()

for k, v in args_dict.items():
k = k.lower()
# if a "which" value exists in the args dict, it should match the command provided
@@ -355,7 +353,9 @@
continue

# param was assigned from defaults and should not be included
if k not in (cmd_args | prnt_args) - default_args:
if k not in (cmd_args | prnt_args) or (
k in default_args and v == FLAGS_DEFAULTS[k.upper()]
):
continue

# if the param is in parent args, it should come before the arg name
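
The new predicate above changes when a parameter is skipped while serializing a command back to CLI arguments: previously any flag listed in FLAGS_DEFAULTS was dropped outright, even if the user had set a non-default value; now it is dropped only when its value still equals the default. A minimal, self-contained sketch of that rule (the sets and defaults below are illustrative, not the real flag inventory):

# Simplified model of the command_params filtering above (illustrative values).
FLAGS_DEFAULTS = {"FULL_REFRESH": False, "WARN_ERROR": None}
cmd_args = {"full_refresh", "select"}   # params of the command itself
prnt_args = {"warn_error"}              # params of the parent (cli) group
default_args = {k.lower() for k in FLAGS_DEFAULTS}

def keep(k: str, v) -> bool:
    # Params the command does not know about are never serialized.
    if k not in (cmd_args | prnt_args):
        return False
    # Default-valued params are skipped; explicitly set ones survive.
    if k in default_args and v == FLAGS_DEFAULTS[k.upper()]:
        return False
    return True

assert keep("full_refresh", True)        # user set it: serialized
assert not keep("full_refresh", False)   # still the default: skipped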
2 changes: 1 addition & 1 deletion core/dbt/cli/main.py
@@ -633,12 +633,12 @@ def run(ctx, **kwargs):
@p.target
@p.state
@p.threads
@p.full_refresh
@requires.postflight
@requires.preflight
@requires.profile
@requires.project
@requires.runtime_config
@requires.manifest
def retry(ctx, **kwargs):
"""Retry the nodes that failed in the previous run."""
task = RetryTask(
25 changes: 5 additions & 20 deletions core/dbt/cli/requires.py
@@ -1,8 +1,7 @@
import dbt.tracking
from dbt.common.invocation import reset_invocation_id
from dbt.mp_context import get_mp_context
from dbt.version import installed as installed_version
from dbt.adapters.factory import adapter_management, register_adapter, get_adapter
from dbt.adapters.factory import adapter_management
from dbt.flags import set_flags, get_flag_dict
from dbt.cli.exceptions import (
ExceptionExit,
@@ -11,7 +10,6 @@
from dbt.cli.flags import Flags
from dbt.config import RuntimeConfig
from dbt.config.runtime import load_project, load_profile, UnsetProfile
from dbt.context.providers import generate_runtime_macro_context

from dbt.common.events.base_types import EventLevel
from dbt.common.events.functions import (
@@ -28,11 +26,11 @@
from dbt.events.types import CommandCompleted, MainEncounteredError, MainStackTrace, ResourceReport
from dbt.common.exceptions import DbtBaseException as DbtException
from dbt.exceptions import DbtProjectError, FailFastError
from dbt.parser.manifest import ManifestLoader, write_manifest
from dbt.parser.manifest import parse_manifest
from dbt.profiler import profiler
from dbt.tracking import active_user, initialize_from_flags, track_run
from dbt.common.utils import cast_dict_to_dict_of_strings
from dbt.plugins import set_up_plugin_manager, get_plugin_manager
from dbt.plugins import set_up_plugin_manager

from click import Context
from functools import update_wrapper
@@ -273,25 +271,12 @@ def wrapper(*args, **kwargs):
raise DbtProjectError("profile, project, and runtime_config required for manifest")

runtime_config = ctx.obj["runtime_config"]
register_adapter(runtime_config, get_mp_context())
adapter = get_adapter(runtime_config)
adapter.set_macro_context_generator(generate_runtime_macro_context)

# a manifest has already been set on the context, so don't overwrite it
if ctx.obj.get("manifest") is None:
manifest = ManifestLoader.get_full_manifest(
runtime_config,
write_perf_info=write_perf_info,
ctx.obj["manifest"] = parse_manifest(
runtime_config, write_perf_info, write, ctx.obj["flags"].write_json
)

ctx.obj["manifest"] = manifest
if write and ctx.obj["flags"].write_json:
write_manifest(manifest, runtime_config.project_target_path)
pm = get_plugin_manager(runtime_config.project_name)
plugin_artifacts = pm.get_manifest_artifacts(manifest)
for path, plugin_artifact in plugin_artifacts.items():
plugin_artifact.write(path)

return func(*args, **kwargs)

return update_wrapper(wrapper, func)
17 changes: 11 additions & 6 deletions core/dbt/contracts/state.py
@@ -9,6 +9,16 @@
from dbt.exceptions import IncompatibleSchemaError


def load_result_state(results_path) -> Optional[RunResultsArtifact]:
if results_path.exists() and results_path.is_file():
try:
return RunResultsArtifact.read_and_check_versions(str(results_path))
except IncompatibleSchemaError as exc:
exc.add_filename(str(results_path))
raise

Codecov warning: added lines core/dbt/contracts/state.py#L16-L18 were not covered by tests.
return None


class PreviousState:
def __init__(self, state_path: Path, target_path: Path, project_root: Path) -> None:
self.state_path: Path = state_path
@@ -32,12 +42,7 @@
raise

results_path = self.project_root / self.state_path / "run_results.json"
if results_path.exists() and results_path.is_file():
try:
self.results = RunResultsArtifact.read_and_check_versions(str(results_path))
except IncompatibleSchemaError as exc:
exc.add_filename(str(results_path))
raise
self.results = load_result_state(results_path)

sources_path = self.project_root / self.state_path / "sources.json"
if sources_path.exists() and sources_path.is_file():
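
The extracted helper also gives callers a lightweight way to read only run_results.json without constructing a full PreviousState. A usage sketch (the project path is hypothetical; the status strings match the NodeStatus values used elsewhere in this PR):

from pathlib import Path
from dbt.contracts.state import load_result_state

# Returns None when no previous run_results.json exists at the given path.
results = load_result_state(Path("my_project") / "target" / "run_results.json")
if results is not None:
    failed = [r.unique_id for r in results.results if r.status in ("error", "fail")]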
26 changes: 21 additions & 5 deletions core/dbt/parser/manifest.py
@@ -24,6 +24,7 @@
from dbt.common.events.base_types import EventLevel
import json
import pprint
from dbt.mp_context import get_mp_context
import msgpack

import dbt.exceptions
@@ -35,6 +36,7 @@
get_adapter,
get_relation_class_by_name,
get_adapter_package_names,
register_adapter,
)
from dbt.constants import (
MANIFEST_FILE_NAME,
@@ -75,7 +77,7 @@
from dbt.context.docs import generate_runtime_docs_context
from dbt.context.macro_resolver import MacroResolver, TestMacroNamespace
from dbt.context.configured import generate_macro_context
from dbt.context.providers import ParseProvider
from dbt.context.providers import ParseProvider, generate_runtime_macro_context
from dbt.contracts.files import FileHash, ParseFileType, SchemaSourceFile
from dbt.parser.read_files import (
ReadFilesFromFileSystem,
@@ -281,7 +283,6 @@
reset: bool = False,
write_perf_info=False,
) -> Manifest:

adapter = get_adapter(config) # type: ignore
# reset is set in a TaskManager load_manifest call, since
# the config and adapter may be persistent.
@@ -593,7 +594,6 @@
node.depends_on
for resolved_ref in resolved_model_refs:
if resolved_ref.deprecation_date:

if resolved_ref.deprecation_date < datetime.datetime.now().astimezone():
event_cls = DeprecatedReference
else:
@@ -1738,7 +1738,6 @@


def _process_sources_for_node(manifest: Manifest, current_project: str, node: ManifestNode):

if isinstance(node, SeedNode):
return

@@ -1780,7 +1779,6 @@
# This is called in task.rpc.sql_commands when a "dynamic" node is
# created in the manifest, in 'add_refs'
def process_node(config: RuntimeConfig, manifest: Manifest, node: ManifestNode):

_process_sources_for_node(manifest, config.project_name, node)
_process_refs(manifest, config.project_name, node, config.dependencies)
ctx = generate_runtime_docs_context(config, node, manifest, config.project_name)
@@ -1798,3 +1796,21 @@
manifest.write(path)

write_semantic_manifest(manifest=manifest, target_path=target_path)


def parse_manifest(runtime_config, write_perf_info, write, write_json):
register_adapter(runtime_config, get_mp_context())
adapter = get_adapter(runtime_config)
adapter.set_macro_context_generator(generate_runtime_macro_context)
manifest = ManifestLoader.get_full_manifest(
runtime_config,
write_perf_info=write_perf_info,
)

if write and write_json:
write_manifest(manifest, runtime_config.project_target_path)
pm = plugins.get_plugin_manager(runtime_config.project_name)
plugin_artifacts = pm.get_manifest_artifacts(manifest)
for path, plugin_artifact in plugin_artifacts.items():
plugin_artifact.write(path)

Codecov warning: added line core/dbt/parser/manifest.py#L1815 was not covered by tests.
return manifest
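
For reference, a sketch of how this new helper is meant to be called (mirroring the call sites in requires.py above and retry.py below; runtime_config is assumed to be an already-resolved RuntimeConfig and flags a resolved Flags object):

from dbt.parser.manifest import parse_manifest

manifest = parse_manifest(
    runtime_config,                # resolved RuntimeConfig (assumed to exist)
    write_perf_info=False,         # skip writing perf_info.json
    write=True,                    # allow writing manifest.json...
    write_json=flags.write_json,   # ...but only when --write-json is in effect
)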
84 changes: 48 additions & 36 deletions core/dbt/task/retry.py
@@ -1,10 +1,13 @@
from pathlib import Path
from click import get_current_context
from click.core import ParameterSource

from dbt.cli.flags import Flags
from dbt.flags import set_flags, get_flags
from dbt.cli.types import Command as CliCommand
from dbt.config import RuntimeConfig
from dbt.contracts.results import NodeStatus
from dbt.contracts.state import PreviousState
from dbt.contracts.state import load_result_state
from dbt.common.exceptions import DbtRuntimeError
from dbt.graph import GraphQueue
from dbt.task.base import ConfiguredTask
@@ -17,9 +20,10 @@
from dbt.task.seed import SeedTask
from dbt.task.snapshot import SnapshotTask
from dbt.task.test import TestTask
from dbt.parser.manifest import parse_manifest

RETRYABLE_STATUSES = {NodeStatus.Error, NodeStatus.Fail, NodeStatus.Skipped, NodeStatus.RuntimeErr}
OVERRIDE_PARENT_FLAGS = {
IGNORE_PARENT_FLAGS = {
[Contributor Author] Updating the name to better reflect what this list is.

"log_path",
"output_path",
"profiles_dir",
@@ -28,8 +32,11 @@
"defer_state",
"deprecated_state",
"target_path",
"warn_error",
[Contributor Author] @graciegoheen how do we feel about warn_error being carried over? Current behavior (see the sketch at the end of this diff):

  • dbt run --warn-error failed because some warning was treated as an error
  • dbt retry will not treat those warnings as errors
  • dbt retry --warn-error will be the way to still treat warnings as errors

This is the previous behavior we have had. Do we think this is good, or do we want to change it?

[Contributor Author] @peterallenwebb I think we discussed that this was a behavior change in my earlier version, but in the latest version it actually is not.

[graciegoheen, Contributor, Jan 5, 2024] Re: "dbt run --warn-error failed ... dbt retry will not treat those warnings as errors": this is surprising to me. It ties back to the desire to have a consistent policy here. I would assume that all flags passed to the original command are pulled through to dbt retry. Is this a specific exception for --warn-error? cc @jtcohen6, do you remember whether this was intentional? Ultimately, if we're worried about breaking something in a backport, we could address this for 1.8.

[Contributor Author] Since warn-error is a flag, if we don't ignore it in retry there is actually no way to specify --no-warn-error in retry. And just to be super clear: "dbt retry will not treat those warnings as errors" is the current behavior of retry.

}

ALLOW_OVERRIDE_FLAGS = {"vars"}

TASK_DICT = {
"build": BuildTask,
"compile": CompileTask,
@@ -57,59 +64,64 @@

class RetryTask(ConfiguredTask):
def __init__(self, args, config, manifest) -> None:
super().__init__(args, config, manifest)

state_path = self.args.state or self.config.target_path

if self.args.warn_error:
RETRYABLE_STATUSES.add(NodeStatus.Warn)

self.previous_state = PreviousState(
[Contributor Author] We do not need the full previous state here, just the run results. Loading the whole state could be memory-intensive for large projects (the manifest can get large).

state_path=Path(state_path),
target_path=Path(self.config.target_path),
project_root=Path(self.config.project_root),
# load previous run results
state_path = args.state or config.target_path
self.previous_results = load_result_state(
Path(config.project_root) / Path(state_path) / "run_results.json"
)

if not self.previous_state.results:
if not self.previous_results:
raise DbtRuntimeError(
f"Could not find previous run in '{state_path}' target directory"
)

self.previous_args = self.previous_state.results.args
self.previous_args = self.previous_results.args
self.previous_command_name = self.previous_args.get("which")
self.task_class = TASK_DICT.get(self.previous_command_name) # type: ignore

def run(self):
unique_ids = set(
[
result.unique_id
for result in self.previous_state.results.results
if result.status in RETRYABLE_STATUSES
]
)

cli_command = CMD_DICT.get(self.previous_command_name)
# Resolve flags and config
if args.warn_error:
RETRYABLE_STATUSES.add(NodeStatus.Warn)

cli_command = CMD_DICT.get(self.previous_command_name) # type: ignore
# Remove these args when their default values are present, otherwise they'll raise an exception
args_to_remove = {
"show": lambda x: True,
"resource_types": lambda x: x == [],
"warn_error_options": lambda x: x == {"exclude": [], "include": []},
}

for k, v in args_to_remove.items():
if k in self.previous_args and v(self.previous_args[k]):
del self.previous_args[k]

previous_args = {
k: v for k, v in self.previous_args.items() if k not in OVERRIDE_PARENT_FLAGS
k: v for k, v in self.previous_args.items() if k not in IGNORE_PARENT_FLAGS
}
click_context = get_current_context()
current_args = {
k: v
for k, v in args.__dict__.items()
if k in IGNORE_PARENT_FLAGS
or (
click_context.get_parameter_source(k) == ParameterSource.COMMANDLINE
and k in ALLOW_OVERRIDE_FLAGS
)
}
current_args = {k: v for k, v in self.args.__dict__.items() if k in OVERRIDE_PARENT_FLAGS}
combined_args = {**previous_args, **current_args}

retry_flags = Flags.from_dict(cli_command, combined_args)
retry_flags = Flags.from_dict(cli_command, combined_args) # type: ignore
set_flags(retry_flags)
retry_config = RuntimeConfig.from_args(args=retry_flags)

# Parse manifest using resolved config/flags
manifest = parse_manifest(retry_config, False, True, retry_flags.write_json) # type: ignore
super().__init__(args, retry_config, manifest)
self.task_class = TASK_DICT.get(self.previous_command_name) # type: ignore

def run(self):
unique_ids = set(
[
result.unique_id
for result in self.previous_results.results
if result.status in RETRYABLE_STATUSES
]
)

class TaskWrapper(self.task_class):
def get_graph_queue(self):
new_graph = self.graph.get_subset_graph(unique_ids)
@@ -120,8 +132,8 @@ def get_graph_queue(self):
)

task = TaskWrapper(
retry_flags,
retry_config,
get_flags(),
self.config,
self.manifest,
)

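
Putting the retry.py changes above together, here is a simplified sketch of the flag-resolution rule retry now follows (the dicts and sets are illustrative and abridged; the real code consults click's ParameterSource, as shown in the diff):

IGNORE_PARENT_FLAGS = {"warn_error", "profiles_dir", "target_path"}  # abridged
ALLOW_OVERRIDE_FLAGS = {"vars"}

def resolve_retry_args(previous_args: dict, current_args: dict, cli_set: set) -> dict:
    # Start from the previous invocation's args, minus flags retry never inherits.
    merged = {k: v for k, v in previous_args.items() if k not in IGNORE_PARENT_FLAGS}
    # Layer on current values for the ignored flags, plus any allow-listed flag
    # (like vars) the user explicitly set on the retry command line.
    for k, v in current_args.items():
        if k in IGNORE_PARENT_FLAGS or (k in cli_set and k in ALLOW_OVERRIDE_FLAGS):
            merged[k] = v
    return merged

previous = {"full_refresh": True, "vars": {"x": 1}, "warn_error": True}
current = {"warn_error": None, "vars": {"x": 2}, "profiles_dir": "~/.dbt"}
print(resolve_retry_args(previous, current, cli_set={"vars"}))
# -> full_refresh is preserved, the retry-time vars win, warn_error is not inherited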