diff --git a/.changes/unreleased/Features-20240818-005131.yaml b/.changes/unreleased/Features-20240818-005131.yaml new file mode 100644 index 00000000..17260206 --- /dev/null +++ b/.changes/unreleased/Features-20240818-005131.yaml @@ -0,0 +1,6 @@ +kind: Features +body: Add Behavior Flag framework +time: 2024-08-18T00:51:31.753656-04:00 +custom: + Author: mikealfare + Issue: "281" diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index caf34209..0f2a03f7 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -38,7 +38,7 @@ repos: - --max-line-length=99 - --select=E,F,W - --ignore=E203,E501,E704,E741,W503,W504 - - --per-file-ignores=*/__init__.py:F401 + - --per-file-ignores=*/__init__.py:F401,*/conftest.py:F401 - repo: https://github.com/pre-commit/mirrors-mypy rev: v1.9.0 diff --git a/dbt/adapters/base/impl.py b/dbt/adapters/base/impl.py index e0627c47..541c9846 100644 --- a/dbt/adapters/base/impl.py +++ b/dbt/adapters/base/impl.py @@ -24,6 +24,7 @@ ) import pytz +from dbt_common.behavior_flags import Behavior, BehaviorFlag from dbt_common.clients.jinja import CallableMacroGenerator from dbt_common.contracts.constraints import ( ColumnLevelConstraint, @@ -261,7 +262,7 @@ class BaseAdapter(metaclass=AdapterMeta): MAX_SCHEMA_METADATA_RELATIONS = 100 - # This static member variable can be overriden in concrete adapter + # This static member variable can be overridden in concrete adapter # implementations to indicate adapter support for optional capabilities. _capabilities = CapabilityDict({}) @@ -271,6 +272,7 @@ def __init__(self, config, mp_context: SpawnContext) -> None: self.connections = self.ConnectionManager(config, mp_context) self._macro_resolver: Optional[MacroResolverProtocol] = None self._macro_context_generator: Optional[MacroContextGeneratorCallable] = None + self.behavior = [] # this will be updated to include global behavior flags once they exist ### # Methods to set / access a macro resolver @@ -291,6 +293,27 @@ def set_macro_context_generator( ) -> None: self._macro_context_generator = macro_context_generator + @property + def behavior(self) -> Behavior: + return self._behavior + + @behavior.setter + def behavior(self, flags: List[BehaviorFlag]) -> None: + flags.extend(self._behavior_flags) + try: + # we don't always get project flags, for example during `dbt debug` + self._behavior = Behavior(flags, self.config.flags) + except AttributeError: + # in that case, don't load any behavior to avoid unexpected defaults + self._behavior = Behavior([], {}) + + @property + def _behavior_flags(self) -> List[BehaviorFlag]: + """ + This method should be overwritten by adapter maintainers to provide platform-specific flags + """ + return [] + ### # Methods that pass through to the connection manager ### diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py new file mode 100644 index 00000000..346634df --- /dev/null +++ b/tests/unit/conftest.py @@ -0,0 +1 @@ +from tests.unit.fixtures import adapter, behavior_flags, config, flags diff --git a/tests/unit/fixtures/__init__.py b/tests/unit/fixtures/__init__.py new file mode 100644 index 00000000..78135a2c --- /dev/null +++ b/tests/unit/fixtures/__init__.py @@ -0,0 +1 @@ +from tests.unit.fixtures.adapter import adapter, behavior_flags, config, flags diff --git a/tests/unit/fixtures/adapter.py b/tests/unit/fixtures/adapter.py new file mode 100644 index 00000000..b59b0423 --- /dev/null +++ b/tests/unit/fixtures/adapter.py @@ -0,0 +1,146 @@ +from multiprocessing import get_context +from types import SimpleNamespace +from typing import Any, Dict, List + +import agate +from dbt_common.behavior_flags import BehaviorFlag +import pytest + +from dbt.adapters.base.column import Column +from dbt.adapters.base.impl import BaseAdapter +from dbt.adapters.base.relation import BaseRelation +from dbt.adapters.contracts.connection import AdapterRequiredConfig, QueryComment + +from tests.unit.fixtures.connection_manager import ConnectionManagerStub +from tests.unit.fixtures.credentials import CredentialsStub + + +@pytest.fixture +def adapter(config, behavior_flags) -> BaseAdapter: + + class BaseAdapterStub(BaseAdapter): + """ + A stub for an adapter that uses the cache as the database + """ + + ConnectionManager = ConnectionManagerStub + + @property + def _behavior_flags(self) -> List[BehaviorFlag]: + return behavior_flags + + ### + # Abstract methods for database-specific values, attributes, and types + ### + @classmethod + def date_function(cls) -> str: + return "date_function" + + @classmethod + def is_cancelable(cls) -> bool: + return False + + def list_schemas(self, database: str) -> List[str]: + return list(self.cache.schemas) + + ### + # Abstract methods about relations + ### + def drop_relation(self, relation: BaseRelation) -> None: + self.cache_dropped(relation) + + def truncate_relation(self, relation: BaseRelation) -> None: + self.cache_dropped(relation) + + def rename_relation(self, from_relation: BaseRelation, to_relation: BaseRelation) -> None: + self.cache_renamed(from_relation, to_relation) + + def get_columns_in_relation(self, relation: BaseRelation) -> List[Column]: + # there's no database, so these need to be added as kwargs in the existing_relations fixture + return relation.columns + + def expand_column_types(self, goal: BaseRelation, current: BaseRelation) -> None: + # there's no database, so these need to be added as kwargs in the existing_relations fixture + object.__setattr__(current, "columns", goal.columns) + + def list_relations_without_caching( + self, schema_relation: BaseRelation + ) -> List[BaseRelation]: + # there's no database, so use the cache as the database + return self.cache.get_relations(schema_relation.database, schema_relation.schema) + + ### + # ODBC FUNCTIONS -- these should not need to change for every adapter, + # although some adapters may override them + ### + def create_schema(self, relation: BaseRelation): + # there's no database, this happens implicitly by adding a relation to the cache + pass + + def drop_schema(self, relation: BaseRelation): + for each_relation in self.cache.get_relations(relation.database, relation.schema): + self.cache_dropped(each_relation) + + @classmethod + def quote(cls, identifier: str) -> str: + quote_char = "" + return f"{quote_char}{identifier}{quote_char}" + + ### + # Conversions: These must be implemented by concrete implementations, for + # converting agate types into their sql equivalents. + ### + @classmethod + def convert_text_type(cls, agate_table: agate.Table, col_idx: int) -> str: + return "str" + + @classmethod + def convert_number_type(cls, agate_table: agate.Table, col_idx: int) -> str: + return "float" + + @classmethod + def convert_boolean_type(cls, agate_table: agate.Table, col_idx: int) -> str: + return "bool" + + @classmethod + def convert_datetime_type(cls, agate_table: agate.Table, col_idx: int) -> str: + return "datetime" + + @classmethod + def convert_date_type(cls, *args, **kwargs): + return "date" + + @classmethod + def convert_time_type(cls, *args, **kwargs): + return "time" + + return BaseAdapterStub(config, get_context("spawn")) + + +@pytest.fixture +def config(flags) -> AdapterRequiredConfig: + raw_config = { + "credentials": CredentialsStub("test_database", "test_schema"), + "profile_name": "test_profile", + "target_name": "test_target", + "threads": 4, + "project_name": "test_project", + "query_comment": QueryComment(), + "cli_vars": {}, + "target_path": "path/to/nowhere", + "log_cache_events": False, + "flags": flags, + } + return SimpleNamespace(**raw_config) + + +@pytest.fixture +def flags() -> Dict[str, Any]: + # this is the flags collection in dbt_project.yaml + return {} + + +@pytest.fixture +def behavior_flags() -> List[BehaviorFlag]: + # this is the collection of behavior flags for a specific adapter + return [] diff --git a/tests/unit/fixtures/connection_manager.py b/tests/unit/fixtures/connection_manager.py new file mode 100644 index 00000000..8b353fbe --- /dev/null +++ b/tests/unit/fixtures/connection_manager.py @@ -0,0 +1,58 @@ +from contextlib import contextmanager +from typing import ContextManager, List, Optional, Tuple + +import agate + +from dbt.adapters.base.connections import BaseConnectionManager +from dbt.adapters.contracts.connection import AdapterResponse, Connection, ConnectionState + + +class ConnectionManagerStub(BaseConnectionManager): + """ + A stub for a connection manager that does not connect to a database + """ + + raised_exceptions: List[Exception] + + @contextmanager + def exception_handler(self, sql: str) -> ContextManager: # type: ignore + # catch all exceptions and put them on this class for inspection in tests + try: + yield + except Exception as exc: + self.raised_exceptions.append(exc) + finally: + pass + + def cancel_open(self) -> Optional[List[str]]: + names = [] + for connection in self.thread_connections.values(): + if connection.state == ConnectionState.OPEN: + connection.state = ConnectionState.CLOSED + if name := connection.name: + names.append(name) + return names + + @classmethod + def open(cls, connection: Connection) -> Connection: + # there's no database, so just change the state + connection.state = ConnectionState.OPEN + return connection + + def begin(self) -> None: + # there's no database, so there are no transactions + pass + + def commit(self) -> None: + # there's no database, so there are no transactions + pass + + def execute( + self, + sql: str, + auto_begin: bool = False, + fetch: bool = False, + limit: Optional[int] = None, + ) -> Tuple[AdapterResponse, agate.Table]: + # there's no database, so just return the sql + return AdapterResponse(_message="", code=sql), agate.Table([]) diff --git a/tests/unit/fixtures/credentials.py b/tests/unit/fixtures/credentials.py new file mode 100644 index 00000000..88817f6b --- /dev/null +++ b/tests/unit/fixtures/credentials.py @@ -0,0 +1,13 @@ +from dbt.adapters.contracts.connection import Credentials + + +class CredentialsStub(Credentials): + """ + A stub for a database credentials that does not connect to a database + """ + + def type(self) -> str: + return "test" + + def _connection_keys(self): + return {"database": self.database, "schema": self.schema} diff --git a/tests/unit/test_behavior_flags.py b/tests/unit/test_behavior_flags.py new file mode 100644 index 00000000..0ae1a021 --- /dev/null +++ b/tests/unit/test_behavior_flags.py @@ -0,0 +1,42 @@ +from typing import Any, Dict, List + +from dbt_common.behavior_flags import BehaviorFlag +from dbt_common.exceptions import DbtBaseException +import pytest + + +@pytest.fixture +def flags() -> Dict[str, Any]: + return { + "unregistered_flag": True, + "default_false_user_false_flag": False, + "default_false_user_true_flag": True, + "default_true_user_false_flag": False, + "default_true_user_true_flag": True, + } + + +@pytest.fixture +def behavior_flags() -> List[BehaviorFlag]: + return [ + {"name": "default_false_user_false_flag", "default": False}, + {"name": "default_false_user_true_flag", "default": False}, + {"name": "default_false_user_skip_flag", "default": False}, + {"name": "default_true_user_false_flag", "default": True}, + {"name": "default_true_user_true_flag", "default": True}, + {"name": "default_true_user_skip_flag", "default": True}, + ] + + +def test_register_behavior_flags(adapter): + # make sure that users cannot add arbitrary flags to this collection + with pytest.raises(DbtBaseException): + assert adapter.behavior.unregistered_flag + + # check the values of the valid behavior flags + assert not adapter.behavior.default_false_user_false_flag + assert adapter.behavior.default_false_user_true_flag + assert not adapter.behavior.default_false_user_skip_flag + assert not adapter.behavior.default_true_user_false_flag + assert adapter.behavior.default_true_user_true_flag + assert adapter.behavior.default_true_user_skip_flag