Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

consolidate flags #6788

Merged
merged 17 commits into from
Feb 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions core/dbt/adapters/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
)
from dbt.events.functions import fire_event, fire_event_if
from dbt.events.types import CacheAction, CacheDumpGraph
import dbt.flags as flags
from dbt.flags import get_flags
from dbt.utils import lowercase


Expand Down Expand Up @@ -319,6 +319,7 @@ def add(self, relation):

:param BaseRelation relation: The underlying relation.
"""
flags = get_flags()
cached = _CachedRelation(relation)
fire_event_if(
flags.LOG_CACHE_EVENTS,
Expand Down Expand Up @@ -456,7 +457,7 @@ def rename(self, old, new):
ref_key_2=_make_msg_from_ref_key(new),
)
)

flags = get_flags()
fire_event_if(
flags.LOG_CACHE_EVENTS,
lambda: CacheDumpGraph(before_after="before", action="rename", dump=self.dump_graph()),
Expand Down
89 changes: 83 additions & 6 deletions core/dbt/cli/flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,30 +12,85 @@

from dbt.config.profile import read_user_config
from dbt.contracts.project import UserConfig
from dbt.helper_types import WarnErrorOptions
from dbt.config.project import PartialProject
from dbt.exceptions import DbtProjectError

if os.name != "nt":
# https://bugs.python.org/issue41567
import multiprocessing.popen_spawn_posix # type: ignore # noqa: F401

# TODO anything that has a default in params should be removed here?
# Or maybe only the ones that's in the root click group
FLAGS_DEFAULTS = {
"INDIRECT_SELECTION": "eager",
"TARGET_PATH": None,
# cli args without user_config or env var option
"FULL_REFRESH": False,
"STRICT_MODE": False,
"STORE_FAILURES": False,
}


# For backwards compatability, some params are defined across multiple levels,
# Top-level value should take precedence.
# e.g. dbt --target-path test2 run --target-path test2
EXPECTED_DUPLICATE_PARAMS = [
"full_refresh",
"target_path",
"version_check",
"fail_fast",
"indirect_selection",
"store_failures",
]


def convert_config(config_name, config_value):
# This function should take care of converting the values from config and original
# set_from_args to the correct type
ret = config_value
if config_name.lower() == "warn_error_options":
ret = WarnErrorOptions(
include=config_value.get("include", []), exclude=config_value.get("exclude", [])
)
return ret


@dataclass(frozen=True)
class Flags:
def __init__(self, ctx: Context = None, user_config: UserConfig = None) -> None:

# set the default flags
for key, value in FLAGS_DEFAULTS.items():
object.__setattr__(self, key, value)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not directly related to this changeset, so feel free to ignore this comment.

I think this flags code would be simpler and more readable if we added _get_flag() and _set_flag() functions which handled the casing concerns and the use of the setattr/getattr functions.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I was chatting with @MichelleArk about similar thing also. @iknox-fa might have something around it, but we are gonna defer that to work later on instead of in this PR

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't do that the first time around because I wanted to keep things as immutable as python allows so adding even a "private" method to do that seemed like a recipe for mutated flags in the codebase in a few months.

Maybe as a compromise for readability we just make nested get and set functions in init? That way no-one uses them unless they're directly editing the flags? We can discuss it further but given how much of a tangled mess the flags are currently I think we're going to get a lot of mileage out of having reasonable assurance of immutability and I don't want to jeopardize that.


if ctx is None:
ctx = get_current_context()

def assign_params(ctx, params_assigned_from_default):
"""Recursively adds all click params to flag object"""
for param_name, param_value in ctx.params.items():
# TODO: this is to avoid duplicate params being defined in two places (version_check in run and cli)
# However this is a bit of a hack and we should find a better way to do this

# N.B. You have to use the base MRO method (object.__setattr__) to set attributes
# when using frozen dataclasses.
# https://docs.python.org/3/library/dataclasses.html#frozen-instances
if hasattr(self, param_name):
raise Exception(f"Duplicate flag names found in click command: {param_name}")
object.__setattr__(self, param_name.upper(), param_value)
if ctx.get_parameter_source(param_name) == ParameterSource.DEFAULT:
params_assigned_from_default.add(param_name)
if hasattr(self, param_name.upper()):
if param_name not in EXPECTED_DUPLICATE_PARAMS:
raise Exception(
f"Duplicate flag names found in click command: {param_name}"
)
else:
# Expected duplicate param from multi-level click command (ex: dbt --full_refresh run --full_refresh)
# Overwrite user-configured param with value from parent context
if ctx.get_parameter_source(param_name) != ParameterSource.DEFAULT:
object.__setattr__(self, param_name.upper(), param_value)
else:
object.__setattr__(self, param_name.upper(), param_value)
if ctx.get_parameter_source(param_name) == ParameterSource.DEFAULT:
params_assigned_from_default.add(param_name)

if ctx.parent:
assign_params(ctx.parent, params_assigned_from_default)

Expand Down Expand Up @@ -64,7 +119,9 @@ def assign_params(ctx, params_assigned_from_default):
user_config_param_value = getattr(user_config, param_assigned_from_default, None)
if user_config_param_value is not None:
object.__setattr__(
self, param_assigned_from_default.upper(), user_config_param_value
self,
param_assigned_from_default.upper(),
convert_config(param_assigned_from_default, user_config_param_value),
)
param_assigned_from_default_copy.remove(param_assigned_from_default)
params_assigned_from_default = param_assigned_from_default_copy
Expand All @@ -73,6 +130,26 @@ def assign_params(ctx, params_assigned_from_default):
object.__setattr__(self, "WHICH", invoked_subcommand_name or ctx.info_name)
object.__setattr__(self, "MP_CONTEXT", get_context("spawn"))

# Default LOG_PATH from PROJECT_DIR, if available.
if getattr(self, "LOG_PATH", None) is None:
log_path = "logs"
project_dir = getattr(self, "PROJECT_DIR", None)
# If available, set LOG_PATH from log-path in dbt_project.yml
# Known limitations:
# 1. Using PartialProject here, so no jinja rendering of log-path.
# 2. Programmatic invocations of the cli via dbtRunner may pass a Project object directly,
# which is not being used here to extract log-path.
if project_dir:
try:
partial = PartialProject.from_project_root(
project_dir, verify_version=getattr(self, "VERSION_CHECK", True)
)
log_path = str(partial.project_dict.get("log-path", log_path))
except DbtProjectError:
pass

object.__setattr__(self, "LOG_PATH", log_path)

# Support console DO NOT TRACK initiave
object.__setattr__(
self,
Expand Down
1 change: 0 additions & 1 deletion core/dbt/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,7 +490,6 @@ def seed(ctx, **kwargs):
ctx.obj["runtime_config"],
ctx.obj["manifest"],
)

results = task.run()
success = task.interpret_results(results)
return results, success
Expand Down
1 change: 1 addition & 0 deletions core/dbt/cli/option_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class WarnErrorOptionsType(YAML):
name = "WarnErrorOptionsType"

def convert(self, value, param, ctx):
# this function is being used by param in click
include_exclude = super().convert(value, param, ctx)

return WarnErrorOptions(
Expand Down
4 changes: 2 additions & 2 deletions core/dbt/cli/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@
"--log-path",
envvar="DBT_LOG_PATH",
help="Configure the 'log-path'. Only applies this setting for the current run. Overrides the 'DBT_LOG_PATH' if it is set.",
default=lambda: Path.cwd() / "logs",
default=None,
type=click.Path(resolve_path=True, path_type=Path),
)

Expand Down Expand Up @@ -415,7 +415,7 @@ def _version_callback(ctx, _param, value):
warn_error_options = click.option(
"--warn-error-options",
envvar="DBT_WARN_ERROR_OPTIONS",
default=None,
default="{}",
help="""If dbt would normally warn, instead raise an exception based on include/exclude configuration. Examples include --select that selects nothing, deprecations, configurations with no associated models, invalid test configurations,
and missing sources/refs in tests. This argument should be a YAML string, with keys 'include' or 'exclude'. eg. '{"include": "all", "exclude": ["NoNodesForSelectionCriteria"]}'""",
type=WarnErrorOptionsType(),
Expand Down
2 changes: 2 additions & 0 deletions core/dbt/cli/requires.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from dbt.adapters.factory import adapter_management, register_adapter
from dbt.flags import set_flags
from dbt.cli.flags import Flags
from dbt.config import RuntimeConfig
from dbt.config.runtime import load_project, load_profile
Expand All @@ -21,6 +22,7 @@ def wrapper(*args, **kwargs):
# Flags
flags = Flags(ctx)
ctx.obj["flags"] = flags
set_flags(flags)

# Tracking
initialize_from_flags(flags.ANONYMOUS_USAGE_STATS, flags.PROFILES_DIR)
Expand Down
7 changes: 4 additions & 3 deletions core/dbt/clients/jinja.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
UndefinedCompilationError,
UndefinedMacroError,
)
from dbt import flags
from dbt.flags import get_flags
from dbt.node_types import ModelLanguage


Expand Down Expand Up @@ -99,8 +99,9 @@ def _compile(self, source, filename):
If the value is 'write', also write the files to disk.
WARNING: This can write a ton of data if you aren't careful.
"""
if filename == "<template>" and flags.MACRO_DEBUGGING:
write = flags.MACRO_DEBUGGING == "write"
macro_debugging = get_flags().MACRO_DEBUGGING
if filename == "<template>" and macro_debugging:
write = macro_debugging == "write"
filename = _linecache_inject(source, write)

return super()._compile(source, filename) # type: ignore
Expand Down
3 changes: 2 additions & 1 deletion core/dbt/compilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from collections import defaultdict
from typing import List, Dict, Any, Tuple, Optional

from dbt import flags
from dbt.flags import get_flags
from dbt.adapters.factory import get_adapter
from dbt.clients import jinja
from dbt.clients.system import make_directory
Expand Down Expand Up @@ -378,6 +378,7 @@ def _compile_node(
def write_graph_file(self, linker: Linker, manifest: Manifest):
filename = graph_file_name
graph_path = os.path.join(self.config.target_path, filename)
flags = get_flags()
if flags.WRITE_JSON:
linker.write_graph(graph_path, manifest)

Expand Down
43 changes: 25 additions & 18 deletions core/dbt/config/profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from dbt.dataclass_schema import ValidationError

from dbt import flags
from dbt.flags import get_flags
from dbt.clients.system import load_file_contents
from dbt.clients.yaml_helper import load_yaml_text
from dbt.contracts.connection import Credentials, HasCredentials
Expand Down Expand Up @@ -32,22 +32,6 @@
"""


NO_SUPPLIED_PROFILE_ERROR = """\
dbt cannot run because no profile was specified for this dbt project.
To specify a profile for this project, add a line like the this to
your dbt_project.yml file:

profile: [profile name]

Here, [profile name] should be replaced with a profile name
defined in your profiles.yml file. You can find profiles.yml here:

{profiles_file}/profiles.yml
""".format(
profiles_file=flags.DEFAULT_PROFILES_DIR
)


def read_profile(profiles_dir: str) -> Dict[str, Any]:
path = os.path.join(profiles_dir, "profiles.yml")

Expand Down Expand Up @@ -197,10 +181,33 @@ def pick_profile_name(
args_profile_name: Optional[str],
project_profile_name: Optional[str] = None,
) -> str:
# TODO: Duplicating this method as direct copy of the implementation in dbt.cli.resolvers
# dbt.cli.resolvers implementation can't be used because it causes a circular dependency.
# This should be removed and use a safe default access on the Flags module when
# https://github.com/dbt-labs/dbt-core/issues/6259 is closed.
def default_profiles_dir():
Copy link
Contributor

@MichelleArk MichelleArk Feb 3, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO: add comment + link issue for default lookup in Flags

from pathlib import Path

return Path.cwd() if (Path.cwd() / "profiles.yml").exists() else Path.home() / ".dbt"

profile_name = project_profile_name
if args_profile_name is not None:
profile_name = args_profile_name
if profile_name is None:
NO_SUPPLIED_PROFILE_ERROR = """\
dbt cannot run because no profile was specified for this dbt project.
To specify a profile for this project, add a line like the this to
your dbt_project.yml file:

profile: [profile name]

Here, [profile name] should be replaced with a profile name
defined in your profiles.yml file. You can find profiles.yml here:

{profiles_file}/profiles.yml
""".format(
profiles_file=default_profiles_dir()
)
raise DbtProjectError(NO_SUPPLIED_PROFILE_ERROR)
return profile_name

Expand Down Expand Up @@ -423,7 +430,7 @@ def render(
target could not be found.
:returns Profile: The new Profile object.
"""

flags = get_flags()
raw_profiles = read_profile(flags.PROFILES_DIR)
profile_name = cls.pick_profile_name(profile_name_override, project_profile_name)
return cls.from_raw_profiles(
Expand Down
11 changes: 8 additions & 3 deletions core/dbt/config/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
import hashlib
import os

from dbt import flags, deprecations
from dbt.flags import get_flags
from dbt import deprecations
from dbt.clients.system import path_exists, resolve_path_from_base, load_file_contents
from dbt.clients.yaml_helper import load_yaml_text
from dbt.contracts.connection import QueryComment
Expand Down Expand Up @@ -373,9 +374,13 @@ def create_project(self, rendered: RenderComponents) -> "Project":

docs_paths: List[str] = value_or(cfg.docs_paths, all_source_paths)
asset_paths: List[str] = value_or(cfg.asset_paths, [])
target_path: str = flag_or(flags.TARGET_PATH, cfg.target_path, "target")
flags = get_flags()

flag_target_path = str(flags.TARGET_PATH) if flags.TARGET_PATH else None
target_path: str = flag_or(flag_target_path, cfg.target_path, "target")

log_path: str = str(flags.LOG_PATH)
clean_targets: List[str] = value_or(cfg.clean_targets, [target_path])
log_path: str = flag_or(flags.LOG_PATH, cfg.log_path, "logs")
packages_install_path: str = value_or(cfg.packages_install_path, "dbt_packages")
# in the default case we'll populate this once we know the adapter type
# It would be nice to just pass along a Quoting here, but that would
Expand Down
6 changes: 3 additions & 3 deletions core/dbt/config/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
Type,
)

from dbt import flags
from dbt.flags import get_flags
from dbt.adapters.factory import get_include_paths, get_relation_class_by_name
from dbt.config.project import load_raw_project
from dbt.contracts.connection import AdapterRequiredConfig, Credentials, HasCredentials
Expand Down Expand Up @@ -197,11 +197,10 @@ def new_project(self, project_root: str) -> "RuntimeConfig":

# load the new project and its packages. Don't pass cli variables.
renderer = DbtProjectYamlRenderer(profile)

project = Project.from_project_root(
project_root,
renderer,
verify_version=bool(flags.VERSION_CHECK),
verify_version=bool(getattr(self.args, "VERSION_CHECK", True)),
)

runtime_config = self.from_parts(
Expand Down Expand Up @@ -247,6 +246,7 @@ def collect_parts(cls: Type["RuntimeConfig"], args: Any) -> Tuple[Project, Profi
cli_vars,
args,
)
flags = get_flags()
project = load_project(project_root, bool(flags.VERSION_CHECK), profile, cli_vars)
return project, profile

Expand Down
Loading