Skip to content

Commit

Permalink
Add Behavior Flag Framework (#282)
Browse files Browse the repository at this point in the history
  • Loading branch information
mikealfare authored Sep 6, 2024
1 parent 7de2e87 commit 8a570fa
Show file tree
Hide file tree
Showing 9 changed files with 292 additions and 2 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Features-20240818-005131.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Features
body: Add Behavior Flag framework
time: 2024-08-18T00:51:31.753656-04:00
custom:
Author: mikealfare
Issue: "281"
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ repos:
- --max-line-length=99
- --select=E,F,W
- --ignore=E203,E501,E704,E741,W503,W504
- --per-file-ignores=*/__init__.py:F401
- --per-file-ignores=*/__init__.py:F401,*/conftest.py:F401

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.9.0
Expand Down
25 changes: 24 additions & 1 deletion dbt/adapters/base/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
)

import pytz
from dbt_common.behavior_flags import Behavior, BehaviorFlag
from dbt_common.clients.jinja import CallableMacroGenerator
from dbt_common.contracts.constraints import (
ColumnLevelConstraint,
Expand Down Expand Up @@ -261,7 +262,7 @@ class BaseAdapter(metaclass=AdapterMeta):

MAX_SCHEMA_METADATA_RELATIONS = 100

# This static member variable can be overriden in concrete adapter
# This static member variable can be overridden in concrete adapter
# implementations to indicate adapter support for optional capabilities.
_capabilities = CapabilityDict({})

Expand All @@ -271,6 +272,7 @@ def __init__(self, config, mp_context: SpawnContext) -> None:
self.connections = self.ConnectionManager(config, mp_context)
self._macro_resolver: Optional[MacroResolverProtocol] = None
self._macro_context_generator: Optional[MacroContextGeneratorCallable] = None
self.behavior = [] # this will be updated to include global behavior flags once they exist

###
# Methods to set / access a macro resolver
Expand All @@ -291,6 +293,27 @@ def set_macro_context_generator(
) -> None:
self._macro_context_generator = macro_context_generator

@property
def behavior(self) -> Behavior:
return self._behavior

@behavior.setter
def behavior(self, flags: List[BehaviorFlag]) -> None:
flags.extend(self._behavior_flags)
try:
# we don't always get project flags, for example during `dbt debug`
self._behavior = Behavior(flags, self.config.flags)
except AttributeError:
# in that case, don't load any behavior to avoid unexpected defaults
self._behavior = Behavior([], {})

@property
def _behavior_flags(self) -> List[BehaviorFlag]:
"""
This method should be overwritten by adapter maintainers to provide platform-specific flags
"""
return []

###
# Methods that pass through to the connection manager
###
Expand Down
1 change: 1 addition & 0 deletions tests/unit/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from tests.unit.fixtures import adapter, behavior_flags, config, flags
1 change: 1 addition & 0 deletions tests/unit/fixtures/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from tests.unit.fixtures.adapter import adapter, behavior_flags, config, flags
146 changes: 146 additions & 0 deletions tests/unit/fixtures/adapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
from multiprocessing import get_context
from types import SimpleNamespace
from typing import Any, Dict, List

import agate
from dbt_common.behavior_flags import BehaviorFlag
import pytest

from dbt.adapters.base.column import Column
from dbt.adapters.base.impl import BaseAdapter
from dbt.adapters.base.relation import BaseRelation
from dbt.adapters.contracts.connection import AdapterRequiredConfig, QueryComment

from tests.unit.fixtures.connection_manager import ConnectionManagerStub
from tests.unit.fixtures.credentials import CredentialsStub


@pytest.fixture
def adapter(config, behavior_flags) -> BaseAdapter:

class BaseAdapterStub(BaseAdapter):
"""
A stub for an adapter that uses the cache as the database
"""

ConnectionManager = ConnectionManagerStub

@property
def _behavior_flags(self) -> List[BehaviorFlag]:
return behavior_flags

###
# Abstract methods for database-specific values, attributes, and types
###
@classmethod
def date_function(cls) -> str:
return "date_function"

@classmethod
def is_cancelable(cls) -> bool:
return False

def list_schemas(self, database: str) -> List[str]:
return list(self.cache.schemas)

###
# Abstract methods about relations
###
def drop_relation(self, relation: BaseRelation) -> None:
self.cache_dropped(relation)

def truncate_relation(self, relation: BaseRelation) -> None:
self.cache_dropped(relation)

def rename_relation(self, from_relation: BaseRelation, to_relation: BaseRelation) -> None:
self.cache_renamed(from_relation, to_relation)

def get_columns_in_relation(self, relation: BaseRelation) -> List[Column]:
# there's no database, so these need to be added as kwargs in the existing_relations fixture
return relation.columns

def expand_column_types(self, goal: BaseRelation, current: BaseRelation) -> None:
# there's no database, so these need to be added as kwargs in the existing_relations fixture
object.__setattr__(current, "columns", goal.columns)

def list_relations_without_caching(
self, schema_relation: BaseRelation
) -> List[BaseRelation]:
# there's no database, so use the cache as the database
return self.cache.get_relations(schema_relation.database, schema_relation.schema)

###
# ODBC FUNCTIONS -- these should not need to change for every adapter,
# although some adapters may override them
###
def create_schema(self, relation: BaseRelation):
# there's no database, this happens implicitly by adding a relation to the cache
pass

def drop_schema(self, relation: BaseRelation):
for each_relation in self.cache.get_relations(relation.database, relation.schema):
self.cache_dropped(each_relation)

@classmethod
def quote(cls, identifier: str) -> str:
quote_char = ""
return f"{quote_char}{identifier}{quote_char}"

###
# Conversions: These must be implemented by concrete implementations, for
# converting agate types into their sql equivalents.
###
@classmethod
def convert_text_type(cls, agate_table: agate.Table, col_idx: int) -> str:
return "str"

@classmethod
def convert_number_type(cls, agate_table: agate.Table, col_idx: int) -> str:
return "float"

@classmethod
def convert_boolean_type(cls, agate_table: agate.Table, col_idx: int) -> str:
return "bool"

@classmethod
def convert_datetime_type(cls, agate_table: agate.Table, col_idx: int) -> str:
return "datetime"

@classmethod
def convert_date_type(cls, *args, **kwargs):
return "date"

@classmethod
def convert_time_type(cls, *args, **kwargs):
return "time"

return BaseAdapterStub(config, get_context("spawn"))


@pytest.fixture
def config(flags) -> AdapterRequiredConfig:
raw_config = {
"credentials": CredentialsStub("test_database", "test_schema"),
"profile_name": "test_profile",
"target_name": "test_target",
"threads": 4,
"project_name": "test_project",
"query_comment": QueryComment(),
"cli_vars": {},
"target_path": "path/to/nowhere",
"log_cache_events": False,
"flags": flags,
}
return SimpleNamespace(**raw_config)


@pytest.fixture
def flags() -> Dict[str, Any]:
# this is the flags collection in dbt_project.yaml
return {}


@pytest.fixture
def behavior_flags() -> List[BehaviorFlag]:
# this is the collection of behavior flags for a specific adapter
return []
58 changes: 58 additions & 0 deletions tests/unit/fixtures/connection_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from contextlib import contextmanager
from typing import ContextManager, List, Optional, Tuple

import agate

from dbt.adapters.base.connections import BaseConnectionManager
from dbt.adapters.contracts.connection import AdapterResponse, Connection, ConnectionState


class ConnectionManagerStub(BaseConnectionManager):
"""
A stub for a connection manager that does not connect to a database
"""

raised_exceptions: List[Exception]

@contextmanager
def exception_handler(self, sql: str) -> ContextManager: # type: ignore
# catch all exceptions and put them on this class for inspection in tests
try:
yield
except Exception as exc:
self.raised_exceptions.append(exc)
finally:
pass

def cancel_open(self) -> Optional[List[str]]:
names = []
for connection in self.thread_connections.values():
if connection.state == ConnectionState.OPEN:
connection.state = ConnectionState.CLOSED
if name := connection.name:
names.append(name)
return names

@classmethod
def open(cls, connection: Connection) -> Connection:
# there's no database, so just change the state
connection.state = ConnectionState.OPEN
return connection

def begin(self) -> None:
# there's no database, so there are no transactions
pass

def commit(self) -> None:
# there's no database, so there are no transactions
pass

def execute(
self,
sql: str,
auto_begin: bool = False,
fetch: bool = False,
limit: Optional[int] = None,
) -> Tuple[AdapterResponse, agate.Table]:
# there's no database, so just return the sql
return AdapterResponse(_message="", code=sql), agate.Table([])
13 changes: 13 additions & 0 deletions tests/unit/fixtures/credentials.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from dbt.adapters.contracts.connection import Credentials


class CredentialsStub(Credentials):
"""
A stub for a database credentials that does not connect to a database
"""

def type(self) -> str:
return "test"

def _connection_keys(self):
return {"database": self.database, "schema": self.schema}
42 changes: 42 additions & 0 deletions tests/unit/test_behavior_flags.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from typing import Any, Dict, List

from dbt_common.behavior_flags import BehaviorFlag
from dbt_common.exceptions import DbtBaseException
import pytest


@pytest.fixture
def flags() -> Dict[str, Any]:
return {
"unregistered_flag": True,
"default_false_user_false_flag": False,
"default_false_user_true_flag": True,
"default_true_user_false_flag": False,
"default_true_user_true_flag": True,
}


@pytest.fixture
def behavior_flags() -> List[BehaviorFlag]:
return [
{"name": "default_false_user_false_flag", "default": False},
{"name": "default_false_user_true_flag", "default": False},
{"name": "default_false_user_skip_flag", "default": False},
{"name": "default_true_user_false_flag", "default": True},
{"name": "default_true_user_true_flag", "default": True},
{"name": "default_true_user_skip_flag", "default": True},
]


def test_register_behavior_flags(adapter):
# make sure that users cannot add arbitrary flags to this collection
with pytest.raises(DbtBaseException):
assert adapter.behavior.unregistered_flag

# check the values of the valid behavior flags
assert not adapter.behavior.default_false_user_false_flag
assert adapter.behavior.default_false_user_true_flag
assert not adapter.behavior.default_false_user_skip_flag
assert not adapter.behavior.default_true_user_false_flag
assert adapter.behavior.default_true_user_true_flag
assert adapter.behavior.default_true_user_skip_flag

0 comments on commit 8a570fa

Please sign in to comment.