diff --git a/.changes/unreleased/Under the Hood-20240104-165248.yaml b/.changes/unreleased/Under the Hood-20240104-165248.yaml new file mode 100644 index 00000000000..867107a54a8 --- /dev/null +++ b/.changes/unreleased/Under the Hood-20240104-165248.yaml @@ -0,0 +1,7 @@ +kind: Under the Hood +body: Accept valid_error_names in WarnErrorOptions constructor, remove global usage + of event modules +time: 2024-01-04T16:52:48.173716-05:00 +custom: + Author: michelleark + Issue: "9337" diff --git a/core/dbt/cli/flags.py b/core/dbt/cli/flags.py index 4324f59fee7..580bddcd372 100644 --- a/core/dbt/cli/flags.py +++ b/core/dbt/cli/flags.py @@ -19,6 +19,7 @@ from dbt.common.clients import jinja from dbt.deprecations import renamed_env_var from dbt.common.helper_types import WarnErrorOptions +from dbt.events import ALL_EVENT_NAMES if os.name != "nt": # https://bugs.python.org/issue41567 @@ -50,7 +51,9 @@ def convert_config(config_name, config_value): ret = config_value if config_name.lower() == "warn_error_options" and type(config_value) == dict: ret = WarnErrorOptions( - include=config_value.get("include", []), exclude=config_value.get("exclude", []) + include=config_value.get("include", []), + exclude=config_value.get("exclude", []), + valid_error_names=ALL_EVENT_NAMES, ) return ret diff --git a/core/dbt/cli/option_types.py b/core/dbt/cli/option_types.py index f56740161be..f673ac279dc 100644 --- a/core/dbt/cli/option_types.py +++ b/core/dbt/cli/option_types.py @@ -1,6 +1,7 @@ from click import ParamType, Choice from dbt.config.utils import parse_cli_yaml_string +from dbt.events import ALL_EVENT_NAMES from dbt.exceptions import ValidationError, OptionNotYamlDictError from dbt.common.exceptions import DbtValidationError @@ -53,7 +54,9 @@ def convert(self, value, param, ctx): include_exclude = super().convert(value, param, ctx) return WarnErrorOptions( - include=include_exclude.get("include", []), exclude=include_exclude.get("exclude", []) + include=include_exclude.get("include", []), + exclude=include_exclude.get("exclude", []), + valid_error_names=ALL_EVENT_NAMES, ) diff --git a/core/dbt/common/helper_types.py b/core/dbt/common/helper_types.py index 2b802e2b1c5..457b02224e5 100644 --- a/core/dbt/common/helper_types.py +++ b/core/dbt/common/helper_types.py @@ -5,17 +5,13 @@ from dataclasses import dataclass, field from typing import Tuple, AbstractSet, Union -from typing import Callable, cast, Generic, Optional, TypeVar, List, NewType, Any, Dict +from typing import Callable, cast, Generic, Optional, TypeVar, List, NewType, Set from dbt.common.dataclass_schema import ( dbtClassMixin, ValidationError, StrEnum, ) -import dbt.adapters.events.types as adapter_dbt_event_types -import dbt.common.events.types as dbt_event_types -import dbt.events.types as core_dbt_event_types - Port = NewType("Port", int) @@ -68,18 +64,18 @@ def _validate_items(self, items: List[str]): class WarnErrorOptions(IncludeExclude): - def _validate_items(self, items: List[str]): - all_event_types: Dict[str, Any] = { - **dbt_event_types.__dict__, - **core_dbt_event_types.__dict__, - **adapter_dbt_event_types.__dict__, - } - valid_exception_names = set( - [name for name, cls in all_event_types.items() if isinstance(cls, type)] - ) + def __init__( + self, + include: Union[str, List[str]], + exclude: Optional[List[str]] = None, + valid_error_names: Optional[Set[str]] = None, + ): + self._valid_error_names: Set[str] = valid_error_names or set() + super().__init__(include=include, exclude=(exclude or [])) + def _validate_items(self, items: List[str]): for item in items: - if item not in valid_exception_names: + if item not in self._valid_error_names: raise ValidationError(f"{item} is not a valid dbt error name.") diff --git a/core/dbt/events/__init__.py b/core/dbt/events/__init__.py index e69de29bb2d..36da848965c 100644 --- a/core/dbt/events/__init__.py +++ b/core/dbt/events/__init__.py @@ -0,0 +1,15 @@ +from typing import Dict, Any, Set + +import dbt.adapters.events.types as adapter_dbt_event_types +import dbt.common.events.types as dbt_event_types +import dbt.events.types as core_dbt_event_types + +ALL_EVENT_TYPES: Dict[str, Any] = { + **dbt_event_types.__dict__, + **core_dbt_event_types.__dict__, + **adapter_dbt_event_types.__dict__, +} + +ALL_EVENT_NAMES: Set[str] = set( + [name for name, cls in ALL_EVENT_TYPES.items() if isinstance(cls, type)] +) diff --git a/tests/unit/test_functions.py b/tests/unit/test_functions.py index 53ba7a5f942..e6572d3cbea 100644 --- a/tests/unit/test_functions.py +++ b/tests/unit/test_functions.py @@ -5,8 +5,10 @@ from dbt.common.events.functions import msg_to_dict, warn_or_error from dbt.events.logging import setup_event_logger from dbt.common.events.types import InfoLevel -from dbt.events.types import NoNodesForSelectionCriteria from dbt.common.exceptions import EventCompilationError +from dbt.events.types import NoNodesForSelectionCriteria +from dbt.adapters.events.types import AdapterDeprecationWarning +from dbt.common.events.types import RetryExternalCall @pytest.mark.parametrize( @@ -30,6 +32,25 @@ def test_warn_or_error_warn_error_options(warn_error_options, expect_compilation warn_or_error(NoNodesForSelectionCriteria()) +@pytest.mark.parametrize( + "error_cls", + [ + NoNodesForSelectionCriteria, # core event + AdapterDeprecationWarning, # adapter event + RetryExternalCall, # common event + ], +) +def test_warn_error_options_captures_all_events(error_cls): + args = Namespace(warn_error_options={"include": [error_cls.__name__]}) + flags.set_from_args(args, {}) + with pytest.raises(EventCompilationError): + warn_or_error(error_cls()) + + args = Namespace(warn_error_options={"include": "*", "exclude": [error_cls.__name__]}) + flags.set_from_args(args, {}) + warn_or_error(error_cls()) + + @pytest.mark.parametrize( "warn_error,expect_compilation_exception", [ diff --git a/tests/unit/test_helper_types.py b/tests/unit/test_helper_types.py index f65ea0c0976..59b05d90ba9 100644 --- a/tests/unit/test_helper_types.py +++ b/tests/unit/test_helper_types.py @@ -30,25 +30,30 @@ def test_includes(self, include, exclude, expected_includes): class TestWarnErrorOptions: def test_init_invalid_error(self): + with pytest.raises(ValidationError): + WarnErrorOptions(include=["InvalidError"], valid_error_names=set(["ValidError"])) + + with pytest.raises(ValidationError): + WarnErrorOptions( + include="*", exclude=["InvalidError"], valid_error_names=set(["ValidError"]) + ) + + def test_init_invalid_error_default_valid_error_names(self): with pytest.raises(ValidationError): WarnErrorOptions(include=["InvalidError"]) with pytest.raises(ValidationError): WarnErrorOptions(include="*", exclude=["InvalidError"]) - @pytest.mark.parametrize( - "valid_error_name", - [ - "NoNodesForSelectionCriteria", # core event - "AdapterDeprecationWarning", # adapter event - "RetryExternalCall", # common event - ], - ) - def test_init_valid_error(self, valid_error_name): - warn_error_options = WarnErrorOptions(include=[valid_error_name]) - assert warn_error_options.include == [valid_error_name] + def test_init_valid_error(self): + warn_error_options = WarnErrorOptions( + include=["ValidError"], valid_error_names=set(["ValidError"]) + ) + assert warn_error_options.include == ["ValidError"] assert warn_error_options.exclude == [] - warn_error_options = WarnErrorOptions(include="*", exclude=[valid_error_name]) + warn_error_options = WarnErrorOptions( + include="*", exclude=["ValidError"], valid_error_names=set(["ValidError"]) + ) assert warn_error_options.include == "*" - assert warn_error_options.exclude == [valid_error_name] + assert warn_error_options.exclude == ["ValidError"]